mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[client] Add IPv6 support to ACL manager, USP filter, and forwarder (#5688)
This commit is contained in:
@@ -238,43 +238,84 @@ func (c *Client) Networks() *NetworkArray {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
routesMap := routeManager.GetClientRoutesWithNetID()
|
||||||
|
v6Merged := route.V6ExitMergeSet(routesMap)
|
||||||
|
resolvedDomains := c.recorder.GetResolvedDomainsStates()
|
||||||
|
|
||||||
networkArray := &NetworkArray{
|
networkArray := &NetworkArray{
|
||||||
items: make([]Network, 0),
|
items: make([]Network, 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
resolvedDomains := c.recorder.GetResolvedDomainsStates()
|
for id, routes := range routesMap {
|
||||||
|
|
||||||
for id, routes := range routeManager.GetClientRoutesWithNetID() {
|
|
||||||
if len(routes) == 0 {
|
if len(routes) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if _, skip := v6Merged[id]; skip {
|
||||||
r := routes[0]
|
|
||||||
domains := c.getNetworkDomainsFromRoute(r, resolvedDomains)
|
|
||||||
netStr := r.Network.String()
|
|
||||||
|
|
||||||
if r.IsDynamic() {
|
|
||||||
netStr = r.Domains.SafeString()
|
|
||||||
}
|
|
||||||
|
|
||||||
routePeer, err := c.recorder.GetPeer(routes[0].Peer)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
network := Network{
|
|
||||||
Name: string(id),
|
network := c.buildNetwork(id, routes, routeSelector.IsSelected(id), resolvedDomains, v6Merged)
|
||||||
Network: netStr,
|
if network == nil {
|
||||||
Peer: routePeer.FQDN,
|
continue
|
||||||
Status: routePeer.ConnStatus.String(),
|
|
||||||
IsSelected: routeSelector.IsSelected(id),
|
|
||||||
Domains: domains,
|
|
||||||
}
|
}
|
||||||
networkArray.Add(network)
|
networkArray.Add(*network)
|
||||||
}
|
}
|
||||||
return networkArray
|
return networkArray
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) buildNetwork(id route.NetID, routes []*route.Route, selected bool, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo, v6Merged map[route.NetID]struct{}) *Network {
|
||||||
|
r := routes[0]
|
||||||
|
netStr := r.Network.String()
|
||||||
|
if r.IsDynamic() {
|
||||||
|
netStr = r.Domains.SafeString()
|
||||||
|
}
|
||||||
|
|
||||||
|
routePeer, err := c.findBestRoutePeer(routes)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("could not get peer info for route %s: %v", id, err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
network := &Network{
|
||||||
|
Name: string(id),
|
||||||
|
Network: netStr,
|
||||||
|
Peer: routePeer.FQDN,
|
||||||
|
Status: routePeer.ConnStatus.String(),
|
||||||
|
IsSelected: selected,
|
||||||
|
Domains: c.getNetworkDomainsFromRoute(r, resolvedDomains),
|
||||||
|
}
|
||||||
|
|
||||||
|
if route.IsV4DefaultRoute(r.Network) && route.HasV6ExitPair(id, v6Merged) {
|
||||||
|
network.Network = "0.0.0.0/0, ::/0"
|
||||||
|
}
|
||||||
|
|
||||||
|
return network
|
||||||
|
}
|
||||||
|
|
||||||
|
// findBestRoutePeer returns the peer actively routing traffic for the given
|
||||||
|
// HA route group. Falls back to the first connected peer, then the first peer.
|
||||||
|
func (c *Client) findBestRoutePeer(routes []*route.Route) (peer.State, error) {
|
||||||
|
netStr := routes[0].Network.String()
|
||||||
|
|
||||||
|
fullStatus := c.recorder.GetFullStatus()
|
||||||
|
for _, p := range fullStatus.Peers {
|
||||||
|
if _, ok := p.GetRoutes()[netStr]; ok {
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range routes {
|
||||||
|
p, err := c.recorder.GetPeer(r.Peer)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if p.ConnStatus == peer.StatusConnected {
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return c.recorder.GetPeer(routes[0].Peer)
|
||||||
|
}
|
||||||
|
|
||||||
// OnUpdatedHostDNS update the DNS servers addresses for root zones
|
// OnUpdatedHostDNS update the DNS servers addresses for root zones
|
||||||
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
||||||
dnsServer, err := dns.GetServerDns()
|
dnsServer, err := dns.GetServerDns()
|
||||||
|
|||||||
@@ -18,9 +18,12 @@ func executeRouteToggle(id string, manager routemanager.Manager,
|
|||||||
netID := route.NetID(id)
|
netID := route.NetID(id)
|
||||||
routes := []route.NetID{netID}
|
routes := []route.NetID{netID}
|
||||||
|
|
||||||
log.Debugf("%s with id: %s", operationName, id)
|
routesMap := manager.GetClientRoutesWithNetID()
|
||||||
|
routes = route.ExpandV6ExitPairs(routes, routesMap)
|
||||||
|
|
||||||
if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil {
|
log.Debugf("%s with ids: %v", operationName, routes)
|
||||||
|
|
||||||
|
if err := routeOperation(routes, maps.Keys(routesMap)); err != nil {
|
||||||
log.Debugf("error when %s: %s", operationName, err)
|
log.Debugf("error when %s: %s", operationName, err)
|
||||||
return fmt.Errorf("error %s: %w", operationName, err)
|
return fmt.Errorf("error %s: %w", operationName, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"regexp"
|
"regexp"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -26,8 +27,9 @@ type Anonymizer struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func DefaultAddresses() (netip.Addr, netip.Addr) {
|
func DefaultAddresses() (netip.Addr, netip.Addr) {
|
||||||
// 198.51.100.0, 100::
|
// 198.51.100.0 (RFC 5737 TEST-NET-2), 2001:db8:ffff:: (RFC 3849 documentation, last /48)
|
||||||
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
|
// The old start 100:: (discard, RFC 6666) is now used for fake IPs on Android.
|
||||||
|
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.MustParseAddr("2001:db8:ffff::")
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
|
func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
|
||||||
@@ -96,6 +98,11 @@ func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Anonymizer) AnonymizeIPString(ip string) string {
|
func (a *Anonymizer) AnonymizeIPString(ip string) string {
|
||||||
|
// Handle CIDR notation (e.g. "2001:db8::/32")
|
||||||
|
if prefix, err := netip.ParsePrefix(ip); err == nil {
|
||||||
|
return a.AnonymizeIP(prefix.Addr()).String() + "/" + strconv.Itoa(prefix.Bits())
|
||||||
|
}
|
||||||
|
|
||||||
addr, err := netip.ParseAddr(ip)
|
addr, err := netip.ParseAddr(ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ip
|
return ip
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
|
|
||||||
func TestAnonymizeIP(t *testing.T) {
|
func TestAnonymizeIP(t *testing.T) {
|
||||||
startIPv4 := netip.MustParseAddr("198.51.100.0")
|
startIPv4 := netip.MustParseAddr("198.51.100.0")
|
||||||
startIPv6 := netip.MustParseAddr("100::")
|
startIPv6 := netip.MustParseAddr("2001:db8:ffff::")
|
||||||
anonymizer := anonymize.NewAnonymizer(startIPv4, startIPv6)
|
anonymizer := anonymize.NewAnonymizer(startIPv4, startIPv6)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@@ -26,9 +26,9 @@ func TestAnonymizeIP(t *testing.T) {
|
|||||||
{"Second Public IPv4", "4.3.2.1", "198.51.100.1"},
|
{"Second Public IPv4", "4.3.2.1", "198.51.100.1"},
|
||||||
{"Repeated IPv4", "1.2.3.4", "198.51.100.0"},
|
{"Repeated IPv4", "1.2.3.4", "198.51.100.0"},
|
||||||
{"Private IPv4", "192.168.1.1", "192.168.1.1"},
|
{"Private IPv4", "192.168.1.1", "192.168.1.1"},
|
||||||
{"First Public IPv6", "2607:f8b0:4005:805::200e", "100::"},
|
{"First Public IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"},
|
||||||
{"Second Public IPv6", "a::b", "100::1"},
|
{"Second Public IPv6", "a::b", "2001:db8:ffff::1"},
|
||||||
{"Repeated IPv6", "2607:f8b0:4005:805::200e", "100::"},
|
{"Repeated IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"},
|
||||||
{"Private IPv6", "fe80::1", "fe80::1"},
|
{"Private IPv6", "fe80::1", "fe80::1"},
|
||||||
{"In Range IPv4", "198.51.100.2", "198.51.100.2"},
|
{"In Range IPv4", "198.51.100.2", "198.51.100.2"},
|
||||||
}
|
}
|
||||||
@@ -274,17 +274,17 @@ func TestAnonymizeString_IPAddresses(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "IPv6 Address",
|
name: "IPv6 Address",
|
||||||
input: "Access attempted from 2001:db8::ff00:42",
|
input: "Access attempted from 2001:db8::ff00:42",
|
||||||
expect: "Access attempted from 100::",
|
expect: "Access attempted from 2001:db8:ffff::",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "IPv6 Address with Port",
|
name: "IPv6 Address with Port",
|
||||||
input: "Access attempted from [2001:db8::ff00:42]:8080",
|
input: "Access attempted from [2001:db8::ff00:42]:8080",
|
||||||
expect: "Access attempted from [100::]:8080",
|
expect: "Access attempted from [2001:db8:ffff::]:8080",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Both IPv4 and IPv6",
|
name: "Both IPv4 and IPv6",
|
||||||
input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43",
|
input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43",
|
||||||
expect: "IPv4: 198.51.100.1 and IPv6: 100::1",
|
expect: "IPv4: 198.51.100.1 and IPv6: 2001:db8:ffff::1",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -787,10 +787,10 @@ func isUnixSocket(path string) bool {
|
|||||||
return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./")
|
return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./")
|
||||||
}
|
}
|
||||||
|
|
||||||
// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces.
|
// normalizeLocalHost converts "*" to "" for binding to all interfaces (dual-stack).
|
||||||
func normalizeLocalHost(host string) string {
|
func normalizeLocalHost(host string) string {
|
||||||
if host == "*" {
|
if host == "*" {
|
||||||
return "0.0.0.0"
|
return ""
|
||||||
}
|
}
|
||||||
return host
|
return host
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -527,10 +527,10 @@ func TestParsePortForward(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "wildcard bind all interfaces",
|
name: "wildcard bind all interfaces",
|
||||||
spec: "*:8080:localhost:80",
|
spec: "*:8080:localhost:80",
|
||||||
expectedLocal: "0.0.0.0:8080",
|
expectedLocal: ":8080",
|
||||||
expectedRemote: "localhost:80",
|
expectedRemote: "localhost:80",
|
||||||
expectError: false,
|
expectError: false,
|
||||||
description: "Wildcard * should bind to all interfaces (0.0.0.0)",
|
description: "Wildcard * should bind to all interfaces (dual-stack)",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "wildcard for port only",
|
name: "wildcard for port only",
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ type aclManager struct {
|
|||||||
entries aclEntries
|
entries aclEntries
|
||||||
optionalEntries map[string][]entry
|
optionalEntries map[string][]entry
|
||||||
ipsetStore *ipsetStore
|
ipsetStore *ipsetStore
|
||||||
|
v6 bool
|
||||||
|
|
||||||
stateManager *statemanager.Manager
|
stateManager *statemanager.Manager
|
||||||
}
|
}
|
||||||
@@ -47,6 +48,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*acl
|
|||||||
entries: make(map[string][][]string),
|
entries: make(map[string][][]string),
|
||||||
optionalEntries: make(map[string][]entry),
|
optionalEntries: make(map[string][]entry),
|
||||||
ipsetStore: newIpsetStore(),
|
ipsetStore: newIpsetStore(),
|
||||||
|
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -81,7 +83,11 @@ func (m *aclManager) AddPeerFiltering(
|
|||||||
chain := chainNameInputRules
|
chain := chainNameInputRules
|
||||||
|
|
||||||
ipsetName = transformIPsetName(ipsetName, sPort, dPort, action)
|
ipsetName = transformIPsetName(ipsetName, sPort, dPort, action)
|
||||||
specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName)
|
if m.v6 && ipsetName != "" {
|
||||||
|
ipsetName += "-v6"
|
||||||
|
}
|
||||||
|
proto := protoForFamily(protocol, m.v6)
|
||||||
|
specs := filterRuleSpecs(ip, proto, sPort, dPort, action, ipsetName)
|
||||||
|
|
||||||
mangleSpecs := slices.Clone(specs)
|
mangleSpecs := slices.Clone(specs)
|
||||||
mangleSpecs = append(mangleSpecs,
|
mangleSpecs = append(mangleSpecs,
|
||||||
@@ -105,6 +111,7 @@ func (m *aclManager) AddPeerFiltering(
|
|||||||
ip: ip.String(),
|
ip: ip.String(),
|
||||||
chain: chain,
|
chain: chain,
|
||||||
specs: specs,
|
specs: specs,
|
||||||
|
v6: m.v6,
|
||||||
}}, nil
|
}}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -157,6 +164,7 @@ func (m *aclManager) AddPeerFiltering(
|
|||||||
ipsetName: ipsetName,
|
ipsetName: ipsetName,
|
||||||
ip: ip.String(),
|
ip: ip.String(),
|
||||||
chain: chain,
|
chain: chain,
|
||||||
|
v6: m.v6,
|
||||||
}
|
}
|
||||||
|
|
||||||
m.updateState()
|
m.updateState()
|
||||||
@@ -376,8 +384,13 @@ func (m *aclManager) updateState() {
|
|||||||
currentState.Lock()
|
currentState.Lock()
|
||||||
defer currentState.Unlock()
|
defer currentState.Unlock()
|
||||||
|
|
||||||
currentState.ACLEntries = m.entries
|
if m.v6 {
|
||||||
currentState.ACLIPsetStore = m.ipsetStore
|
currentState.ACLEntries6 = m.entries
|
||||||
|
currentState.ACLIPsetStore6 = m.ipsetStore
|
||||||
|
} else {
|
||||||
|
currentState.ACLEntries = m.entries
|
||||||
|
currentState.ACLIPsetStore = m.ipsetStore
|
||||||
|
}
|
||||||
|
|
||||||
if err := m.stateManager.UpdateState(currentState); err != nil {
|
if err := m.stateManager.UpdateState(currentState); err != nil {
|
||||||
log.Errorf("failed to update state: %v", err)
|
log.Errorf("failed to update state: %v", err)
|
||||||
@@ -385,6 +398,15 @@ func (m *aclManager) updateState() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// filterRuleSpecs returns the specs of a filtering rule
|
// filterRuleSpecs returns the specs of a filtering rule
|
||||||
|
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
|
||||||
|
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
|
||||||
|
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
|
||||||
|
if v6 && protocol == firewall.ProtocolICMP {
|
||||||
|
return "ipv6-icmp"
|
||||||
|
}
|
||||||
|
return string(protocol)
|
||||||
|
}
|
||||||
|
|
||||||
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
||||||
// don't use IP matching if IP is 0.0.0.0
|
// don't use IP matching if IP is 0.0.0.0
|
||||||
matchByIP := !ip.IsUnspecified()
|
matchByIP := !ip.IsUnspecified()
|
||||||
@@ -437,6 +459,9 @@ func (m *aclManager) createIPSet(name string) error {
|
|||||||
opts := ipset.CreateOptions{
|
opts := ipset.CreateOptions{
|
||||||
Replace: true,
|
Replace: true,
|
||||||
}
|
}
|
||||||
|
if m.v6 {
|
||||||
|
opts.Family = ipset.FamilyIPV6
|
||||||
|
}
|
||||||
|
|
||||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||||
|
|||||||
@@ -17,6 +17,10 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type resetter interface {
|
||||||
|
Reset() error
|
||||||
|
}
|
||||||
|
|
||||||
// Manager of iptables firewall
|
// Manager of iptables firewall
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
@@ -27,6 +31,11 @@ type Manager struct {
|
|||||||
aclMgr *aclManager
|
aclMgr *aclManager
|
||||||
router *router
|
router *router
|
||||||
rawSupported bool
|
rawSupported bool
|
||||||
|
|
||||||
|
// IPv6 counterparts, nil when no v6 overlay
|
||||||
|
ipv6Client *iptables.IPTables
|
||||||
|
aclMgr6 *aclManager
|
||||||
|
router6 *router
|
||||||
}
|
}
|
||||||
|
|
||||||
// iFaceMapper defines subset methods of interface required for manager
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
@@ -58,9 +67,43 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
|||||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if wgIface.Address().HasIPv6() {
|
||||||
|
if err := m.createIPv6Components(wgIface, mtu); err != nil {
|
||||||
|
return nil, fmt.Errorf("create IPv6 firewall: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error {
|
||||||
|
ip6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("init ip6tables: %w", err)
|
||||||
|
}
|
||||||
|
m.ipv6Client = ip6Client
|
||||||
|
|
||||||
|
m.router6, err = newRouter(ip6Client, wgIface, mtu)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create v6 router: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Share the same IP forwarding state with the v4 router, since
|
||||||
|
// EnableIPForwarding controls both v4 and v6 sysctls.
|
||||||
|
m.router6.ipFwdState = m.router.ipFwdState
|
||||||
|
|
||||||
|
m.aclMgr6, err = newAclManager(ip6Client, wgIface)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create v6 acl manager: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) hasIPv6() bool {
|
||||||
|
return m.ipv6Client != nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||||
state := &ShutdownState{
|
state := &ShutdownState{
|
||||||
InterfaceState: &InterfaceState{
|
InterfaceState: &InterfaceState{
|
||||||
@@ -75,13 +118,8 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
log.Errorf("failed to update state: %v", err)
|
log.Errorf("failed to update state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.router.init(stateManager); err != nil {
|
if err := m.initChains(stateManager); err != nil {
|
||||||
return fmt.Errorf("router init: %w", err)
|
return err
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.aclMgr.init(stateManager); err != nil {
|
|
||||||
// TODO: cleanup router
|
|
||||||
return fmt.Errorf("acl manager init: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.initNoTrackChain(); err != nil {
|
if err := m.initNoTrackChain(); err != nil {
|
||||||
@@ -98,6 +136,41 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// initChains initializes router and ACL chains for both address families,
|
||||||
|
// rolling back on failure.
|
||||||
|
func (m *Manager) initChains(stateManager *statemanager.Manager) error {
|
||||||
|
type initStep struct {
|
||||||
|
name string
|
||||||
|
init func(*statemanager.Manager) error
|
||||||
|
mgr resetter
|
||||||
|
}
|
||||||
|
|
||||||
|
steps := []initStep{
|
||||||
|
{"router", m.router.init, m.router},
|
||||||
|
{"acl manager", m.aclMgr.init, m.aclMgr},
|
||||||
|
}
|
||||||
|
if m.hasIPv6() {
|
||||||
|
steps = append(steps,
|
||||||
|
initStep{"v6 router", m.router6.init, m.router6},
|
||||||
|
initStep{"v6 acl manager", m.aclMgr6.init, m.aclMgr6},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
var initialized []initStep
|
||||||
|
for _, s := range steps {
|
||||||
|
if err := s.init(stateManager); err != nil {
|
||||||
|
for i := len(initialized) - 1; i >= 0; i-- {
|
||||||
|
if rerr := initialized[i].mgr.Reset(); rerr != nil {
|
||||||
|
log.Warnf("rollback %s: %v", initialized[i].name, rerr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Errorf("%s init: %w", s.name, err)
|
||||||
|
}
|
||||||
|
initialized = append(initialized, s)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// AddPeerFiltering adds a rule to the firewall
|
// AddPeerFiltering adds a rule to the firewall
|
||||||
//
|
//
|
||||||
// Comment will be ignored because some system this feature is not supported
|
// Comment will be ignored because some system this feature is not supported
|
||||||
@@ -113,7 +186,13 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
if ip.To4() != nil {
|
||||||
|
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||||
|
}
|
||||||
|
if !m.hasIPv6() {
|
||||||
|
return nil, fmt.Errorf("IPv6 not initialized, cannot add rule for %s", ip)
|
||||||
|
}
|
||||||
|
return m.aclMgr6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(
|
func (m *Manager) AddRouteFiltering(
|
||||||
@@ -127,25 +206,48 @@ func (m *Manager) AddRouteFiltering(
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
|
if isIPv6RouteRule(sources, destination) {
|
||||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
|
if !m.hasIPv6() {
|
||||||
|
return nil, fmt.Errorf("IPv6 not initialized, cannot add route rule")
|
||||||
|
}
|
||||||
|
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
|
||||||
|
if destination.IsPrefix() {
|
||||||
|
return destination.Prefix.Addr().Is6()
|
||||||
|
}
|
||||||
|
return len(sources) > 0 && sources[0].Addr().Is6()
|
||||||
|
}
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
// DeletePeerRule from the firewall by rule definition
|
||||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if m.hasIPv6() && isIPv6IptRule(rule) {
|
||||||
|
return m.aclMgr6.DeletePeerRule(rule)
|
||||||
|
}
|
||||||
return m.aclMgr.DeletePeerRule(rule)
|
return m.aclMgr.DeletePeerRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isIPv6IptRule(rule firewall.Rule) bool {
|
||||||
|
r, ok := rule.(*Rule)
|
||||||
|
return ok && r.v6
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteRouteRule deletes a routing rule.
|
||||||
|
// Route rules are keyed by content hash. Check v4 first, try v6 if not found.
|
||||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if m.hasIPv6() && !m.router.hasRule(rule.ID()) {
|
||||||
|
return m.router6.DeleteRouteRule(rule)
|
||||||
|
}
|
||||||
return m.router.DeleteRouteRule(rule)
|
return m.router.DeleteRouteRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -161,18 +263,63 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.router.AddNatRule(pair)
|
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
|
||||||
|
if !m.hasIPv6() {
|
||||||
|
return fmt.Errorf("IPv6 not initialized, cannot add NAT rule")
|
||||||
|
}
|
||||||
|
return m.router6.AddNatRule(pair)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.router.AddNatRule(pair); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dynamic routes need NAT in both tables
|
||||||
|
if m.hasIPv6() && pair.Destination.IsSet() {
|
||||||
|
v6Pair := pair
|
||||||
|
v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||||
|
if err := m.router6.AddNatRule(v6Pair); err != nil {
|
||||||
|
return fmt.Errorf("add v6 NAT rule: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.router.RemoveNatRule(pair)
|
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
|
||||||
|
if !m.hasIPv6() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return m.router6.RemoveNatRule(pair)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.router.RemoveNatRule(pair); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.hasIPv6() && pair.Destination.IsSet() {
|
||||||
|
v6Pair := pair
|
||||||
|
v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||||
|
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
|
||||||
|
return fmt.Errorf("remove v6 NAT rule: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||||
return firewall.SetLegacyManagement(m.router, isLegacy)
|
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if m.hasIPv6() {
|
||||||
|
return firewall.SetLegacyManagement(m.router6, isLegacy)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
@@ -186,6 +333,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if m.hasIPv6() {
|
||||||
|
if err := m.aclMgr6.Reset(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("reset v6 acl manager: %w", err))
|
||||||
|
}
|
||||||
|
if err := m.router6.Reset(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := m.aclMgr.Reset(); err != nil {
|
if err := m.aclMgr.Reset(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
|
||||||
}
|
}
|
||||||
@@ -209,19 +365,16 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := m.AddPeerFiltering(
|
var merr *multierror.Error
|
||||||
nil,
|
if _, err := m.aclMgr.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
|
||||||
net.IP{0, 0, 0, 0},
|
merr = multierror.Append(merr, fmt.Errorf("allow netbird interface traffic: %w", err))
|
||||||
firewall.ProtocolALL,
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
firewall.ActionAccept,
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
|
||||||
}
|
}
|
||||||
return nil
|
if m.hasIPv6() {
|
||||||
|
if _, err := m.aclMgr6.AddPeerFiltering(nil, net.IPv6zero, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("allow v6 netbird interface traffic: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
@@ -251,6 +404,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if m.hasIPv6() && rule.TranslatedAddress.Is6() {
|
||||||
|
return m.router6.AddDNATRule(rule)
|
||||||
|
}
|
||||||
return m.router.AddDNATRule(rule)
|
return m.router.AddDNATRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -259,6 +415,9 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if m.hasIPv6() && !m.router.hasRule(rule.ID()+dnatSuffix) {
|
||||||
|
return m.router6.DeleteDNATRule(rule)
|
||||||
|
}
|
||||||
return m.router.DeleteDNATRule(rule)
|
return m.router.DeleteDNATRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -267,7 +426,26 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.router.UpdateSet(set, prefixes)
|
var v4Prefixes, v6Prefixes []netip.Prefix
|
||||||
|
for _, p := range prefixes {
|
||||||
|
if p.Addr().Is6() {
|
||||||
|
v6Prefixes = append(v6Prefixes, p)
|
||||||
|
} else {
|
||||||
|
v4Prefixes = append(v4Prefixes, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.hasIPv6() && len(v6Prefixes) > 0 {
|
||||||
|
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
|
||||||
|
return fmt.Errorf("update v6 set: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||||
@@ -275,6 +453,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if m.hasIPv6() && localAddr.Is6() {
|
||||||
|
return m.router6.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
|
}
|
||||||
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -283,6 +464,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if m.hasIPv6() && localAddr.Is6() {
|
||||||
|
return m.router6.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
|
}
|
||||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -54,8 +54,10 @@ const (
|
|||||||
snatSuffix = "_snat"
|
snatSuffix = "_snat"
|
||||||
fwdSuffix = "_fwd"
|
fwdSuffix = "_fwd"
|
||||||
|
|
||||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
|
||||||
ipTCPHeaderMinSize = 40
|
ipv4TCPHeaderSize = 40
|
||||||
|
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
|
||||||
|
ipv6TCPHeaderSize = 60
|
||||||
)
|
)
|
||||||
|
|
||||||
type ruleInfo struct {
|
type ruleInfo struct {
|
||||||
@@ -86,6 +88,7 @@ type router struct {
|
|||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
legacyManagement bool
|
legacyManagement bool
|
||||||
mtu uint16
|
mtu uint16
|
||||||
|
v6 bool
|
||||||
|
|
||||||
stateManager *statemanager.Manager
|
stateManager *statemanager.Manager
|
||||||
ipFwdState *ipfwdstate.IPForwardingState
|
ipFwdState *ipfwdstate.IPForwardingState
|
||||||
@@ -97,6 +100,7 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1
|
|||||||
rules: make(map[string][]string),
|
rules: make(map[string][]string),
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
|
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
|
||||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -186,6 +190,11 @@ func (r *router) AddRouteFiltering(
|
|||||||
return ruleKey, nil
|
return ruleKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *router) hasRule(id string) bool {
|
||||||
|
_, ok := r.rules[id]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
ruleKey := rule.ID()
|
ruleKey := rule.ID()
|
||||||
|
|
||||||
@@ -434,6 +443,12 @@ func (r *router) createContainers() error {
|
|||||||
{chainRTRDR, tableNat},
|
{chainRTRDR, tableNat},
|
||||||
{chainRTMSSCLAMP, tableMangle},
|
{chainRTMSSCLAMP, tableMangle},
|
||||||
} {
|
} {
|
||||||
|
// Fallback: clear chains that survived an unclean shutdown.
|
||||||
|
if ok, _ := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain); ok {
|
||||||
|
if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||||
|
log.Warnf("clear stale chain %s in %s: %v", chainInfo.chain, chainInfo.table, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
|
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||||
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||||
}
|
}
|
||||||
@@ -540,9 +555,12 @@ func (r *router) addPostroutingRules() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||||
// TODO: Add IPv6 support
|
|
||||||
func (r *router) addMSSClampingRules() error {
|
func (r *router) addMSSClampingRules() error {
|
||||||
mss := r.mtu - ipTCPHeaderMinSize
|
overhead := uint16(ipv4TCPHeaderSize)
|
||||||
|
if r.v6 {
|
||||||
|
overhead = ipv6TCPHeaderSize
|
||||||
|
}
|
||||||
|
mss := r.mtu - overhead
|
||||||
|
|
||||||
// Add jump rule from FORWARD chain in mangle table to our custom chain
|
// Add jump rule from FORWARD chain in mangle table to our custom chain
|
||||||
jumpRule := []string{
|
jumpRule := []string{
|
||||||
@@ -727,8 +745,13 @@ func (r *router) updateState() {
|
|||||||
currentState.Lock()
|
currentState.Lock()
|
||||||
defer currentState.Unlock()
|
defer currentState.Unlock()
|
||||||
|
|
||||||
currentState.RouteRules = r.rules
|
if r.v6 {
|
||||||
currentState.RouteIPsetCounter = r.ipsetCounter
|
currentState.RouteRules6 = r.rules
|
||||||
|
currentState.RouteIPsetCounter6 = r.ipsetCounter
|
||||||
|
} else {
|
||||||
|
currentState.RouteRules = r.rules
|
||||||
|
currentState.RouteIPsetCounter = r.ipsetCounter
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.stateManager.UpdateState(currentState); err != nil {
|
if err := r.stateManager.UpdateState(currentState); err != nil {
|
||||||
log.Errorf("failed to update state: %v", err)
|
log.Errorf("failed to update state: %v", err)
|
||||||
@@ -856,7 +879,7 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists {
|
if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists {
|
||||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil {
|
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDOUT, fwdRule...); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
|
||||||
}
|
}
|
||||||
delete(r.rules, ruleKey+fwdSuffix)
|
delete(r.rules, ruleKey+fwdSuffix)
|
||||||
@@ -883,7 +906,7 @@ func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []net
|
|||||||
rule = append(rule, destExp...)
|
rule = append(rule, destExp...)
|
||||||
|
|
||||||
if params.Proto != firewall.ProtocolALL {
|
if params.Proto != firewall.ProtocolALL {
|
||||||
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
|
rule = append(rule, "-p", strings.ToLower(protoForFamily(params.Proto, r.v6)))
|
||||||
rule = append(rule, applyPort("--sport", params.SPort)...)
|
rule = append(rule, applyPort("--sport", params.SPort)...)
|
||||||
rule = append(rule, applyPort("--dport", params.DPort)...)
|
rule = append(rule, applyPort("--dport", params.DPort)...)
|
||||||
}
|
}
|
||||||
@@ -900,11 +923,12 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []
|
|||||||
}
|
}
|
||||||
|
|
||||||
if network.IsSet() {
|
if network.IsSet() {
|
||||||
if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil {
|
name := r.ipsetName(network.Set.HashedName())
|
||||||
|
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
|
||||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil
|
return []string{"-m", "set", matchSet, name, direction}, nil
|
||||||
}
|
}
|
||||||
if network.IsPrefix() {
|
if network.IsPrefix() {
|
||||||
return []string{flag, network.Prefix.String()}, nil
|
return []string{flag, network.Prefix.String()}, nil
|
||||||
@@ -915,19 +939,15 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
|
name := r.ipsetName(set.HashedName())
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
for _, prefix := range prefixes {
|
for _, prefix := range prefixes {
|
||||||
// TODO: Implement IPv6 support
|
if err := r.addPrefixToIPSet(name, prefix); err != nil {
|
||||||
if prefix.Addr().Is6() {
|
|
||||||
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if merr == nil {
|
if merr == nil {
|
||||||
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
|
log.Debugf("updated set %s with prefixes %v", name, prefixes)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
@@ -943,7 +963,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
|||||||
|
|
||||||
dnatRule := []string{
|
dnatRule := []string{
|
||||||
"-i", r.wgIface.Name(),
|
"-i", r.wgIface.Name(),
|
||||||
"-p", strings.ToLower(string(protocol)),
|
"-p", strings.ToLower(protoForFamily(protocol, r.v6)),
|
||||||
"--dport", strconv.Itoa(int(sourcePort)),
|
"--dport", strconv.Itoa(int(sourcePort)),
|
||||||
"-d", localAddr.String(),
|
"-d", localAddr.String(),
|
||||||
"-m", "addrtype", "--dst-type", "LOCAL",
|
"-m", "addrtype", "--dst-type", "LOCAL",
|
||||||
@@ -1076,10 +1096,22 @@ func applyPort(flag string, port *firewall.Port) []string {
|
|||||||
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ipsetName returns the ipset name, suffixed with "-v6" for the v6 router
|
||||||
|
// to avoid collisions since ipsets are global in the kernel.
|
||||||
|
func (r *router) ipsetName(name string) string {
|
||||||
|
if r.v6 {
|
||||||
|
return name + "-v6"
|
||||||
|
}
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
func (r *router) createIPSet(name string) error {
|
func (r *router) createIPSet(name string) error {
|
||||||
opts := ipset.CreateOptions{
|
opts := ipset.CreateOptions{
|
||||||
Replace: true,
|
Replace: true,
|
||||||
}
|
}
|
||||||
|
if r.v6 {
|
||||||
|
opts.Family = ipset.FamilyIPV6
|
||||||
|
}
|
||||||
|
|
||||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ type Rule struct {
|
|||||||
mangleSpecs []string
|
mangleSpecs []string
|
||||||
ip string
|
ip string
|
||||||
chain string
|
chain string
|
||||||
|
v6 bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
@@ -37,6 +39,12 @@ type ShutdownState struct {
|
|||||||
|
|
||||||
ACLEntries aclEntries `json:"acl_entries,omitempty"`
|
ACLEntries aclEntries `json:"acl_entries,omitempty"`
|
||||||
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
|
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
|
||||||
|
|
||||||
|
// IPv6 counterparts
|
||||||
|
RouteRules6 routeRules `json:"route_rules_v6,omitempty"`
|
||||||
|
RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"`
|
||||||
|
ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"`
|
||||||
|
ACLIPsetStore6 *ipsetStore `json:"acl_ipset_store_v6,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ShutdownState) Name() string {
|
func (s *ShutdownState) Name() string {
|
||||||
@@ -67,6 +75,28 @@ func (s *ShutdownState) Cleanup() error {
|
|||||||
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
|
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clean up v6 state even if the current run has no IPv6.
|
||||||
|
// The previous run may have left ip6tables rules behind.
|
||||||
|
if !ipt.hasIPv6() {
|
||||||
|
if err := ipt.createIPv6Components(s.InterfaceState, mtu); err != nil {
|
||||||
|
log.Warnf("failed to create v6 components for cleanup: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ipt.hasIPv6() {
|
||||||
|
if s.RouteRules6 != nil {
|
||||||
|
ipt.router6.rules = s.RouteRules6
|
||||||
|
}
|
||||||
|
if s.RouteIPsetCounter6 != nil {
|
||||||
|
ipt.router6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
|
||||||
|
}
|
||||||
|
if s.ACLEntries6 != nil {
|
||||||
|
ipt.aclMgr6.entries = s.ACLEntries6
|
||||||
|
}
|
||||||
|
if s.ACLIPsetStore6 != nil {
|
||||||
|
ipt.aclMgr6.ipsetStore = s.ACLIPsetStore6
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := ipt.Close(nil); err != nil {
|
if err := ipt.Close(nil); err != nil {
|
||||||
return fmt.Errorf("reset iptables manager: %w", err)
|
return fmt.Errorf("reset iptables manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,15 +33,12 @@ const (
|
|||||||
|
|
||||||
const flushError = "flush: %w"
|
const flushError = "flush: %w"
|
||||||
|
|
||||||
var (
|
|
||||||
anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
|
||||||
)
|
|
||||||
|
|
||||||
type AclManager struct {
|
type AclManager struct {
|
||||||
rConn *nftables.Conn
|
rConn *nftables.Conn
|
||||||
sConn *nftables.Conn
|
sConn *nftables.Conn
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
routingFwChainName string
|
routingFwChainName string
|
||||||
|
af addrFamily
|
||||||
|
|
||||||
workTable *nftables.Table
|
workTable *nftables.Table
|
||||||
chainInputRules *nftables.Chain
|
chainInputRules *nftables.Chain
|
||||||
@@ -67,6 +64,7 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam
|
|||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
workTable: table,
|
workTable: table,
|
||||||
routingFwChainName: routingFwChainName,
|
routingFwChainName: routingFwChainName,
|
||||||
|
af: familyForAddr(table.Family == nftables.TableFamilyIPv4),
|
||||||
|
|
||||||
ipsetStore: newIpsetStore(),
|
ipsetStore: newIpsetStore(),
|
||||||
rules: make(map[string]*Rule),
|
rules: make(map[string]*Rule),
|
||||||
@@ -145,7 +143,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := ips[r.ip.String()]; ok {
|
if _, ok := ips[r.ip.String()]; ok {
|
||||||
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}})
|
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: ipToBytes(r.ip, m.af)}})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err)
|
log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err)
|
||||||
}
|
}
|
||||||
@@ -254,11 +252,11 @@ func (m *AclManager) addIOFiltering(
|
|||||||
expressions = append(expressions, &expr.Payload{
|
expressions = append(expressions, &expr.Payload{
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
Offset: uint32(9),
|
Offset: m.af.protoOffset,
|
||||||
Len: uint32(1),
|
Len: uint32(1),
|
||||||
})
|
})
|
||||||
|
|
||||||
protoData, err := protoToInt(proto)
|
protoData, err := m.af.protoNum(proto)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("convert protocol to number: %v", err)
|
return nil, fmt.Errorf("convert protocol to number: %v", err)
|
||||||
}
|
}
|
||||||
@@ -270,19 +268,16 @@ func (m *AclManager) addIOFiltering(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
rawIP := ip.To4()
|
rawIP := ipToBytes(ip, m.af)
|
||||||
// check if rawIP contains zeroed IPv4 0.0.0.0 value
|
// check if rawIP contains zeroed IPv4 0.0.0.0 value
|
||||||
// in that case not add IP match expression into the rule definition
|
// in that case not add IP match expression into the rule definition
|
||||||
if !bytes.HasPrefix(anyIP, rawIP) {
|
if slices.ContainsFunc(rawIP, func(v byte) bool { return v != 0 }) {
|
||||||
// source address position
|
|
||||||
addrOffset := uint32(12)
|
|
||||||
|
|
||||||
expressions = append(expressions,
|
expressions = append(expressions,
|
||||||
&expr.Payload{
|
&expr.Payload{
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
Offset: addrOffset,
|
Offset: m.af.srcAddrOffset,
|
||||||
Len: 4,
|
Len: m.af.addrLen,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
// add individual IP for match if no ipset defined
|
// add individual IP for match if no ipset defined
|
||||||
@@ -587,7 +582,7 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr
|
|||||||
|
|
||||||
func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) {
|
func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) {
|
||||||
ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName)
|
ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName)
|
||||||
rawIP := ip.To4()
|
rawIP := ipToBytes(ip, m.af)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if ipset, err = m.createSet(m.workTable, ipsetName); err != nil {
|
if ipset, err = m.createSet(m.workTable, ipsetName); err != nil {
|
||||||
return nil, fmt.Errorf("get set name: %v", err)
|
return nil, fmt.Errorf("get set name: %v", err)
|
||||||
@@ -619,7 +614,7 @@ func (m *AclManager) createSet(table *nftables.Table, name string) (*nftables.Se
|
|||||||
Name: name,
|
Name: name,
|
||||||
Table: table,
|
Table: table,
|
||||||
Dynamic: true,
|
Dynamic: true,
|
||||||
KeyType: nftables.TypeIPAddr,
|
KeyType: m.af.setKeyType,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.rConn.AddSet(ipset, nil); err != nil {
|
if err := m.rConn.AddSet(ipset, nil); err != nil {
|
||||||
@@ -707,15 +702,12 @@ func ifname(n string) []byte {
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func protoToInt(protocol firewall.Protocol) (uint8, error) {
|
|
||||||
switch protocol {
|
|
||||||
case firewall.ProtocolTCP:
|
|
||||||
return unix.IPPROTO_TCP, nil
|
|
||||||
case firewall.ProtocolUDP:
|
|
||||||
return unix.IPPROTO_UDP, nil
|
|
||||||
case firewall.ProtocolICMP:
|
|
||||||
return unix.IPPROTO_ICMP, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
|
// ipToBytes converts net.IP to the correct byte length for the address family.
|
||||||
|
func ipToBytes(ip net.IP, af addrFamily) []byte {
|
||||||
|
if af.addrLen == 4 {
|
||||||
|
return ip.To4()
|
||||||
|
}
|
||||||
|
return ip.To16()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
81
client/firewall/nftables/addr_family_linux.go
Normal file
81
client/firewall/nftables/addr_family_linux.go
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/google/nftables"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// afIPv4 defines IPv4 header layout and nftables types.
|
||||||
|
afIPv4 = addrFamily{
|
||||||
|
protoOffset: 9,
|
||||||
|
srcAddrOffset: 12,
|
||||||
|
dstAddrOffset: 16,
|
||||||
|
addrLen: net.IPv4len,
|
||||||
|
totalBits: 8 * net.IPv4len,
|
||||||
|
setKeyType: nftables.TypeIPAddr,
|
||||||
|
tableFamily: nftables.TableFamilyIPv4,
|
||||||
|
icmpProto: unix.IPPROTO_ICMP,
|
||||||
|
}
|
||||||
|
// afIPv6 defines IPv6 header layout and nftables types.
|
||||||
|
afIPv6 = addrFamily{
|
||||||
|
protoOffset: 6,
|
||||||
|
srcAddrOffset: 8,
|
||||||
|
dstAddrOffset: 24,
|
||||||
|
addrLen: net.IPv6len,
|
||||||
|
totalBits: 8 * net.IPv6len,
|
||||||
|
setKeyType: nftables.TypeIP6Addr,
|
||||||
|
tableFamily: nftables.TableFamilyIPv6,
|
||||||
|
icmpProto: unix.IPPROTO_ICMPV6,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// addrFamily holds protocol-specific constants for nftables expression building.
|
||||||
|
type addrFamily struct {
|
||||||
|
// protoOffset is the IP header offset for the protocol/next-header field (9 for v4, 6 for v6)
|
||||||
|
protoOffset uint32
|
||||||
|
// srcAddrOffset is the IP header offset for the source address (12 for v4, 8 for v6)
|
||||||
|
srcAddrOffset uint32
|
||||||
|
// dstAddrOffset is the IP header offset for the destination address (16 for v4, 24 for v6)
|
||||||
|
dstAddrOffset uint32
|
||||||
|
// addrLen is the byte length of addresses (4 for v4, 16 for v6)
|
||||||
|
addrLen uint32
|
||||||
|
// totalBits is the address size in bits (32 for v4, 128 for v6)
|
||||||
|
totalBits int
|
||||||
|
// setKeyType is the nftables set data type for addresses
|
||||||
|
setKeyType nftables.SetDatatype
|
||||||
|
// tableFamily is the nftables table family
|
||||||
|
tableFamily nftables.TableFamily
|
||||||
|
// icmpProto is the ICMP protocol number for this family (1 for v4, 58 for v6)
|
||||||
|
icmpProto uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
// familyForAddr returns the address family for the given IP.
|
||||||
|
func familyForAddr(is4 bool) addrFamily {
|
||||||
|
if is4 {
|
||||||
|
return afIPv4
|
||||||
|
}
|
||||||
|
return afIPv6
|
||||||
|
}
|
||||||
|
|
||||||
|
// protoNum converts a firewall protocol to the IP protocol number,
|
||||||
|
// using the correct ICMP variant for the address family.
|
||||||
|
func (af addrFamily) protoNum(protocol firewall.Protocol) (uint8, error) {
|
||||||
|
switch protocol {
|
||||||
|
case firewall.ProtocolTCP:
|
||||||
|
return unix.IPPROTO_TCP, nil
|
||||||
|
case firewall.ProtocolUDP:
|
||||||
|
return unix.IPPROTO_UDP, nil
|
||||||
|
case firewall.ProtocolICMP:
|
||||||
|
return af.icmpProto, nil
|
||||||
|
case firewall.ProtocolALL:
|
||||||
|
return 0, nil
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -11,9 +11,11 @@ import (
|
|||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
"github.com/google/nftables/binaryutil"
|
"github.com/google/nftables/binaryutil"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@@ -49,8 +51,13 @@ type Manager struct {
|
|||||||
rConn *nftables.Conn
|
rConn *nftables.Conn
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
|
|
||||||
router *router
|
router *router
|
||||||
aclManager *AclManager
|
aclManager *AclManager
|
||||||
|
|
||||||
|
// IPv6 counterparts, nil when no v6 overlay
|
||||||
|
router6 *router
|
||||||
|
aclManager6 *AclManager
|
||||||
|
|
||||||
notrackOutputChain *nftables.Chain
|
notrackOutputChain *nftables.Chain
|
||||||
notrackPreroutingChain *nftables.Chain
|
notrackPreroutingChain *nftables.Chain
|
||||||
}
|
}
|
||||||
@@ -62,7 +69,8 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
|||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
}
|
}
|
||||||
|
|
||||||
workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}
|
tableName := getTableName()
|
||||||
|
workTable := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
m.router, err = newRouter(workTable, wgIface, mtu)
|
m.router, err = newRouter(workTable, wgIface, mtu)
|
||||||
@@ -75,11 +83,70 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
|||||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if wgIface.Address().HasIPv6() {
|
||||||
|
if err := m.createIPv6Components(tableName, wgIface, mtu); err != nil {
|
||||||
|
return nil, fmt.Errorf("create IPv6 firewall: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mtu uint16) error {
|
||||||
|
workTable6 := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
m.router6, err = newRouter(workTable6, wgIface, mtu)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create v6 router: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Share the same IP forwarding state with the v4 router, since
|
||||||
|
// EnableIPForwarding controls both v4 and v6 sysctls.
|
||||||
|
m.router6.ipFwdState = m.router.ipFwdState
|
||||||
|
|
||||||
|
m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create v6 acl manager: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasIPv6 reports whether the manager has IPv6 components initialized.
|
||||||
|
func (m *Manager) hasIPv6() bool {
|
||||||
|
return m.router6 != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) initIPv6() error {
|
||||||
|
workTable6, err := m.createWorkTableFamily(nftables.TableFamilyIPv6)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create v6 work table: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.router6.init(workTable6); err != nil {
|
||||||
|
return fmt.Errorf("v6 router init: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.aclManager6.init(workTable6); err != nil {
|
||||||
|
return fmt.Errorf("v6 acl manager init: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Init nftables firewall manager
|
// Init nftables firewall manager
|
||||||
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||||
|
if err := m.initFirewall(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
m.persistState(stateManager)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) initFirewall() error {
|
||||||
workTable, err := m.createWorkTable()
|
workTable, err := m.createWorkTable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create work table: %w", err)
|
return fmt.Errorf("create work table: %w", err)
|
||||||
@@ -90,20 +157,32 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := m.aclManager.init(workTable); err != nil {
|
if err := m.aclManager.init(workTable); err != nil {
|
||||||
// TODO: cleanup router
|
m.rollbackInit()
|
||||||
return fmt.Errorf("acl manager init: %w", err)
|
return fmt.Errorf("acl manager init: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if m.hasIPv6() {
|
||||||
|
if err := m.initIPv6(); err != nil {
|
||||||
|
// Peer has a v6 address: v6 firewall MUST work or we risk fail-open.
|
||||||
|
m.rollbackInit()
|
||||||
|
return fmt.Errorf("init IPv6 firewall (required because peer has IPv6 address): %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := m.initNoTrackChains(workTable); err != nil {
|
if err := m.initNoTrackChains(workTable); err != nil {
|
||||||
log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err)
|
log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// persistState saves the current interface state for potential recreation on restart.
|
||||||
|
// Unlike iptables, which requires tracking individual rules, nftables maintains
|
||||||
|
// a known state (our netbird table plus a few static rules). This allows for easy
|
||||||
|
// cleanup using Close() without needing to store specific rules.
|
||||||
|
func (m *Manager) persistState(stateManager *statemanager.Manager) {
|
||||||
stateManager.RegisterState(&ShutdownState{})
|
stateManager.RegisterState(&ShutdownState{})
|
||||||
|
|
||||||
// We only need to record minimal interface state for potential recreation.
|
|
||||||
// Unlike iptables, which requires tracking individual rules, nftables maintains
|
|
||||||
// a known state (our netbird table plus a few static rules). This allows for easy
|
|
||||||
// cleanup using Close() without needing to store specific rules.
|
|
||||||
if err := stateManager.UpdateState(&ShutdownState{
|
if err := stateManager.UpdateState(&ShutdownState{
|
||||||
InterfaceState: &InterfaceState{
|
InterfaceState: &InterfaceState{
|
||||||
NameStr: m.wgIface.Name(),
|
NameStr: m.wgIface.Name(),
|
||||||
@@ -115,14 +194,29 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
log.Errorf("failed to update state: %v", err)
|
log.Errorf("failed to update state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// persist early
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||||
log.Errorf("failed to persist state: %v", err)
|
log.Errorf("failed to persist state: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
// rollbackInit performs best-effort cleanup of already-initialized state when Init fails partway through.
|
||||||
|
func (m *Manager) rollbackInit() {
|
||||||
|
if err := m.router.Reset(); err != nil {
|
||||||
|
log.Warnf("rollback router: %v", err)
|
||||||
|
}
|
||||||
|
if m.hasIPv6() {
|
||||||
|
if err := m.router6.Reset(); err != nil {
|
||||||
|
log.Warnf("rollback v6 router: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := m.cleanupNetbirdTables(); err != nil {
|
||||||
|
log.Warnf("cleanup tables: %v", err)
|
||||||
|
}
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
log.Warnf("flush: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddPeerFiltering rule to the firewall
|
// AddPeerFiltering rule to the firewall
|
||||||
@@ -141,12 +235,14 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
rawIP := ip.To4()
|
if ip.To4() != nil {
|
||||||
if rawIP == nil {
|
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||||
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
if !m.hasIPv6() {
|
||||||
|
return nil, fmt.Errorf("IPv6 not initialized, cannot add rule for %s", ip)
|
||||||
|
}
|
||||||
|
return m.aclManager6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(
|
func (m *Manager) AddRouteFiltering(
|
||||||
@@ -160,8 +256,11 @@ func (m *Manager) AddRouteFiltering(
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
|
if isIPv6RouteRule(sources, destination) {
|
||||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
|
if !m.hasIPv6() {
|
||||||
|
return nil, fmt.Errorf("IPv6 not initialized, cannot add route rule")
|
||||||
|
}
|
||||||
|
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
@@ -172,14 +271,38 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if m.hasIPv6() && isIPv6Rule(rule) {
|
||||||
|
return m.aclManager6.DeletePeerRule(rule)
|
||||||
|
}
|
||||||
return m.aclManager.DeletePeerRule(rule)
|
return m.aclManager.DeletePeerRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRouteRule deletes a routing rule
|
func isIPv6Rule(rule firewall.Rule) bool {
|
||||||
|
r, ok := rule.(*Rule)
|
||||||
|
return ok && r.nftRule != nil && r.nftRule.Table != nil && r.nftRule.Table.Family == nftables.TableFamilyIPv6
|
||||||
|
}
|
||||||
|
|
||||||
|
// isIPv6RouteRule determines whether a route rule belongs to the v6 table.
|
||||||
|
// For static routes, the destination prefix determines the family. For dynamic
|
||||||
|
// routes (DomainSet), the sources determine the family since management
|
||||||
|
// duplicates dynamic rules per family.
|
||||||
|
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
|
||||||
|
if destination.IsPrefix() {
|
||||||
|
return destination.Prefix.Addr().Is6()
|
||||||
|
}
|
||||||
|
return len(sources) > 0 && sources[0].Addr().Is6()
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteRouteRule deletes a routing rule.
|
||||||
|
// Route rules are keyed by content hash, so the rule exists in exactly one
|
||||||
|
// router. We check v4 first; if the key isn't there, try v6.
|
||||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if m.hasIPv6() && !m.router.hasRule(rule.ID()) {
|
||||||
|
return m.router6.DeleteRouteRule(rule)
|
||||||
|
}
|
||||||
return m.router.DeleteRouteRule(rule)
|
return m.router.DeleteRouteRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -195,17 +318,63 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.router.AddNatRule(pair)
|
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
|
||||||
|
if !m.hasIPv6() {
|
||||||
|
return fmt.Errorf("IPv6 not initialized, cannot add NAT rule")
|
||||||
|
}
|
||||||
|
return m.router6.AddNatRule(pair)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.router.AddNatRule(pair); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dynamic routes (DomainSet) need NAT in both tables since resolved IPs
|
||||||
|
// can be either v4 or v6.
|
||||||
|
if m.hasIPv6() && pair.Destination.IsSet() {
|
||||||
|
v6Pair := pair
|
||||||
|
v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||||
|
if err := m.router6.AddNatRule(v6Pair); err != nil {
|
||||||
|
return fmt.Errorf("add v6 NAT rule: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.router.RemoveNatRule(pair)
|
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
|
||||||
|
if !m.hasIPv6() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return m.router6.RemoveNatRule(pair)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.router.RemoveNatRule(pair); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.hasIPv6() && pair.Destination.IsSet() {
|
||||||
|
v6Pair := pair
|
||||||
|
v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||||
|
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
|
||||||
|
return fmt.Errorf("remove v6 NAT rule: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllowNetbird allows netbird interface traffic
|
// AllowNetbird allows netbird interface traffic.
|
||||||
|
// TODO: In USP mode this only adds ACCEPT to the netbird table's own chains,
|
||||||
|
// which doesn't override DROP rules in external tables (e.g. firewalld).
|
||||||
|
// Should add passthrough rules to external chains (like the native mode router's
|
||||||
|
// addExternalChainsRules does) for both the netbird table family and inet tables.
|
||||||
|
// The netbird table itself is fine (routing chains already exist there), but
|
||||||
|
// non-netbird tables with INPUT/FORWARD hooks can still DROP our WG traffic.
|
||||||
func (m *Manager) AllowNetbird() error {
|
func (m *Manager) AllowNetbird() error {
|
||||||
if !m.wgIface.IsUserspaceBind() {
|
if !m.wgIface.IsUserspaceBind() {
|
||||||
return nil
|
return nil
|
||||||
@@ -217,6 +386,11 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
if err := m.aclManager.createDefaultAllowRules(); err != nil {
|
if err := m.aclManager.createDefaultAllowRules(); err != nil {
|
||||||
return fmt.Errorf("create default allow rules: %w", err)
|
return fmt.Errorf("create default allow rules: %w", err)
|
||||||
}
|
}
|
||||||
|
if m.hasIPv6() {
|
||||||
|
if err := m.aclManager6.createDefaultAllowRules(); err != nil {
|
||||||
|
return fmt.Errorf("create v6 default allow rules: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
if err := m.rConn.Flush(); err != nil {
|
if err := m.rConn.Flush(); err != nil {
|
||||||
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
||||||
}
|
}
|
||||||
@@ -226,7 +400,13 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
|
|
||||||
// SetLegacyManagement sets the route manager to use legacy management
|
// SetLegacyManagement sets the route manager to use legacy management
|
||||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||||
return firewall.SetLegacyManagement(m.router, isLegacy)
|
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if m.hasIPv6() {
|
||||||
|
return firewall.SetLegacyManagement(m.router6, isLegacy)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the firewall manager
|
// Close closes the firewall manager
|
||||||
@@ -234,23 +414,31 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
if err := m.router.Reset(); err != nil {
|
if err := m.router.Reset(); err != nil {
|
||||||
return fmt.Errorf("reset router: %v", err)
|
merr = multierror.Append(merr, fmt.Errorf("reset router: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.hasIPv6() {
|
||||||
|
if err := m.router6.Reset(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %v", err))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.cleanupNetbirdTables(); err != nil {
|
if err := m.cleanupNetbirdTables(); err != nil {
|
||||||
return fmt.Errorf("cleanup netbird tables: %v", err)
|
merr = multierror.Append(merr, fmt.Errorf("cleanup netbird tables: %v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
if err := m.rConn.Flush(); err != nil {
|
||||||
return fmt.Errorf(flushError, err)
|
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||||
return fmt.Errorf("delete state: %v", err)
|
merr = multierror.Append(merr, fmt.Errorf("delete state: %v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) cleanupNetbirdTables() error {
|
func (m *Manager) cleanupNetbirdTables() error {
|
||||||
@@ -299,6 +487,12 @@ func (m *Manager) Flush() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if m.hasIPv6() {
|
||||||
|
if err := m.aclManager6.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush v6 acl: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := m.refreshNoTrackChains(); err != nil {
|
if err := m.refreshNoTrackChains(); err != nil {
|
||||||
log.Errorf("failed to refresh notrack chains: %v", err)
|
log.Errorf("failed to refresh notrack chains: %v", err)
|
||||||
}
|
}
|
||||||
@@ -311,6 +505,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if m.hasIPv6() && rule.TranslatedAddress.Is6() {
|
||||||
|
return m.router6.AddDNATRule(rule)
|
||||||
|
}
|
||||||
return m.router.AddDNATRule(rule)
|
return m.router.AddDNATRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -319,6 +516,9 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if m.hasIPv6() && !m.router.hasDNATRule(rule.ID()) {
|
||||||
|
return m.router6.DeleteDNATRule(rule)
|
||||||
|
}
|
||||||
return m.router.DeleteDNATRule(rule)
|
return m.router.DeleteDNATRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -327,7 +527,26 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.router.UpdateSet(set, prefixes)
|
var v4Prefixes, v6Prefixes []netip.Prefix
|
||||||
|
for _, p := range prefixes {
|
||||||
|
if p.Addr().Is6() {
|
||||||
|
v6Prefixes = append(v6Prefixes, p)
|
||||||
|
} else {
|
||||||
|
v4Prefixes = append(v4Prefixes, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.hasIPv6() && len(v6Prefixes) > 0 {
|
||||||
|
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
|
||||||
|
return fmt.Errorf("update v6 set: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||||
@@ -335,6 +554,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if m.hasIPv6() && localAddr.Is6() {
|
||||||
|
return m.router6.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
|
}
|
||||||
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -343,6 +565,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if m.hasIPv6() && localAddr.Is6() {
|
||||||
|
return m.router6.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
|
}
|
||||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -533,7 +758,11 @@ func (m *Manager) refreshNoTrackChains() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
return m.createWorkTableFamily(nftables.TableFamilyIPv4)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) createWorkTableFamily(family nftables.TableFamily) (*nftables.Table, error) {
|
||||||
|
tables, err := m.rConn.ListTablesOfFamily(family)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("list of tables: %w", err)
|
return nil, fmt.Errorf("list of tables: %w", err)
|
||||||
}
|
}
|
||||||
@@ -545,7 +774,7 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4})
|
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: family})
|
||||||
err = m.rConn.Flush()
|
err = m.rConn.Flush()
|
||||||
return table, err
|
return table, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -385,10 +385,134 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
|||||||
err = manager.AddNatRule(pair)
|
err = manager.AddNatRule(pair)
|
||||||
require.NoError(t, err, "failed to add NAT rule")
|
require.NoError(t, err, "failed to add NAT rule")
|
||||||
|
|
||||||
|
dnatRule, err := manager.AddDNATRule(fw.ForwardRule{
|
||||||
|
Protocol: fw.ProtocolTCP,
|
||||||
|
DestinationPort: fw.Port{Values: []uint16{8080}},
|
||||||
|
TranslatedAddress: netip.MustParseAddr("100.96.0.2"),
|
||||||
|
TranslatedPort: fw.Port{Values: []uint16{80}},
|
||||||
|
})
|
||||||
|
require.NoError(t, err, "failed to add DNAT rule")
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.DeleteDNATRule(dnatRule), "failed to delete DNAT rule")
|
||||||
|
})
|
||||||
|
|
||||||
stdout, stderr = runIptablesSave(t)
|
stdout, stderr = runIptablesSave(t)
|
||||||
verifyIptablesOutput(t, stdout, stderr)
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNftablesManagerIPv6CompatibilityWithIp6tables(t *testing.T) {
|
||||||
|
if check() != NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, bin := range []string{"ip6tables", "ip6tables-save", "iptables-save"} {
|
||||||
|
if _, err := exec.LookPath(bin); err != nil {
|
||||||
|
t.Skipf("%s not available on this system: %v", bin, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Seed ip6 tables in the nft backend. Docker may not create them.
|
||||||
|
seedIp6tables(t)
|
||||||
|
|
||||||
|
ifaceMockV6 := &iFaceMock{
|
||||||
|
NameFunc: func() string { return "wt-test" },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: netip.MustParseAddr("100.96.0.1"),
|
||||||
|
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
||||||
|
IPv6: netip.MustParseAddr("fd00::1"),
|
||||||
|
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMockV6, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err, "create manager")
|
||||||
|
require.NoError(t, manager.Init(nil))
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil), "close manager")
|
||||||
|
|
||||||
|
stdout, stderr := runIp6tablesSave(t)
|
||||||
|
verifyIp6tablesOutput(t, stdout, stderr)
|
||||||
|
})
|
||||||
|
|
||||||
|
ip := netip.MustParseAddr("fd00::2")
|
||||||
|
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||||
|
require.NoError(t, err, "add v6 peer filtering rule")
|
||||||
|
|
||||||
|
_, err = manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("fd00:1::/64")},
|
||||||
|
fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err, "add v6 route filtering rule")
|
||||||
|
|
||||||
|
err = manager.AddNatRule(fw.RouterPair{
|
||||||
|
Source: fw.Network{Prefix: netip.MustParsePrefix("fd00::/64")},
|
||||||
|
Destination: fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
|
||||||
|
Masquerade: true,
|
||||||
|
})
|
||||||
|
require.NoError(t, err, "add v6 NAT rule")
|
||||||
|
|
||||||
|
dnatRule, err := manager.AddDNATRule(fw.ForwardRule{
|
||||||
|
Protocol: fw.ProtocolTCP,
|
||||||
|
DestinationPort: fw.Port{Values: []uint16{8080}},
|
||||||
|
TranslatedAddress: netip.MustParseAddr("fd00::2"),
|
||||||
|
TranslatedPort: fw.Port{Values: []uint16{80}},
|
||||||
|
})
|
||||||
|
require.NoError(t, err, "add v6 DNAT rule")
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.DeleteDNATRule(dnatRule), "delete v6 DNAT rule")
|
||||||
|
})
|
||||||
|
|
||||||
|
stdout, stderr := runIptablesSave(t)
|
||||||
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
|
|
||||||
|
stdout, stderr = runIp6tablesSave(t)
|
||||||
|
verifyIp6tablesOutput(t, stdout, stderr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func seedIp6tables(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
for _, tc := range []struct{ table, chain string }{
|
||||||
|
{"filter", "FORWARD"},
|
||||||
|
{"nat", "POSTROUTING"},
|
||||||
|
{"mangle", "FORWARD"},
|
||||||
|
} {
|
||||||
|
add := exec.Command("ip6tables", "-t", tc.table, "-A", tc.chain, "-j", "ACCEPT")
|
||||||
|
require.NoError(t, add.Run(), "seed ip6tables -t %s", tc.table)
|
||||||
|
del := exec.Command("ip6tables", "-t", tc.table, "-D", tc.chain, "-j", "ACCEPT")
|
||||||
|
require.NoError(t, del.Run(), "unseed ip6tables -t %s", tc.table)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runIp6tablesSave(t *testing.T) (string, string) {
|
||||||
|
t.Helper()
|
||||||
|
var stdout, stderr bytes.Buffer
|
||||||
|
cmd := exec.Command("ip6tables-save")
|
||||||
|
cmd.Stdout = &stdout
|
||||||
|
cmd.Stderr = &stderr
|
||||||
|
require.NoError(t, cmd.Run(), "ip6tables-save failed")
|
||||||
|
return stdout.String(), stderr.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyIp6tablesOutput(t *testing.T, stdout, stderr string) {
|
||||||
|
t.Helper()
|
||||||
|
require.NotContains(t, stdout, "Table `nat' is incompatible",
|
||||||
|
"ip6tables-save: nat table incompatible. Full output: %s", stdout)
|
||||||
|
require.NotContains(t, stdout, "Table `mangle' is incompatible",
|
||||||
|
"ip6tables-save: mangle table incompatible. Full output: %s", stdout)
|
||||||
|
require.NotContains(t, stdout, "Table `filter' is incompatible",
|
||||||
|
"ip6tables-save: filter table incompatible. Full output: %s", stdout)
|
||||||
|
}
|
||||||
|
|
||||||
func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
|
func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
|
||||||
if check() != NFTABLES {
|
if check() != NFTABLES {
|
||||||
t.Skip("nftables not supported on this system")
|
t.Skip("nftables not supported on this system")
|
||||||
|
|||||||
@@ -47,8 +47,10 @@ const (
|
|||||||
dnatSuffix = "_dnat"
|
dnatSuffix = "_dnat"
|
||||||
snatSuffix = "_snat"
|
snatSuffix = "_snat"
|
||||||
|
|
||||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
|
||||||
ipTCPHeaderMinSize = 40
|
ipv4TCPHeaderSize = 40
|
||||||
|
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
|
||||||
|
ipv6TCPHeaderSize = 60
|
||||||
|
|
||||||
// maxPrefixesSet 1638 prefixes start to fail, taking some margin
|
// maxPrefixesSet 1638 prefixes start to fail, taking some margin
|
||||||
maxPrefixesSet = 1500
|
maxPrefixesSet = 1500
|
||||||
@@ -73,6 +75,7 @@ type router struct {
|
|||||||
rules map[string]*nftables.Rule
|
rules map[string]*nftables.Rule
|
||||||
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
|
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
|
||||||
|
|
||||||
|
af addrFamily
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
ipFwdState *ipfwdstate.IPForwardingState
|
ipFwdState *ipfwdstate.IPForwardingState
|
||||||
legacyManagement bool
|
legacyManagement bool
|
||||||
@@ -85,6 +88,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
|
|||||||
workTable: workTable,
|
workTable: workTable,
|
||||||
chains: make(map[string]*nftables.Chain),
|
chains: make(map[string]*nftables.Chain),
|
||||||
rules: make(map[string]*nftables.Rule),
|
rules: make(map[string]*nftables.Rule),
|
||||||
|
af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4),
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
@@ -143,7 +147,7 @@ func (r *router) Reset() error {
|
|||||||
func (r *router) removeNatPreroutingRules() error {
|
func (r *router) removeNatPreroutingRules() error {
|
||||||
table := &nftables.Table{
|
table := &nftables.Table{
|
||||||
Name: tableNat,
|
Name: tableNat,
|
||||||
Family: nftables.TableFamilyIPv4,
|
Family: r.af.tableFamily,
|
||||||
}
|
}
|
||||||
chain := &nftables.Chain{
|
chain := &nftables.Chain{
|
||||||
Name: chainNameNatPrerouting,
|
Name: chainNameNatPrerouting,
|
||||||
@@ -176,7 +180,7 @@ func (r *router) removeNatPreroutingRules() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
tables, err := r.conn.ListTablesOfFamily(r.af.tableFamily)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("list tables: %w", err)
|
return nil, fmt.Errorf("list tables: %w", err)
|
||||||
}
|
}
|
||||||
@@ -408,7 +412,7 @@ func (r *router) AddRouteFiltering(
|
|||||||
|
|
||||||
// Handle protocol
|
// Handle protocol
|
||||||
if proto != firewall.ProtocolALL {
|
if proto != firewall.ProtocolALL {
|
||||||
protoNum, err := protoToInt(proto)
|
protoNum, err := r.af.protoNum(proto)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||||
}
|
}
|
||||||
@@ -468,7 +472,24 @@ func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bo
|
|||||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return getIpSetExprs(ref, isSource)
|
return r.getIpSetExprs(ref, isSource)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) iptablesProto() iptables.Protocol {
|
||||||
|
if r.af.tableFamily == nftables.TableFamilyIPv6 {
|
||||||
|
return iptables.ProtocolIPv6
|
||||||
|
}
|
||||||
|
return iptables.ProtocolIPv4
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) hasRule(id string) bool {
|
||||||
|
_, ok := r.rules[id]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) hasDNATRule(id string) bool {
|
||||||
|
_, ok := r.rules[id+dnatSuffix]
|
||||||
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
@@ -517,10 +538,10 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
|
|||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
// required for prefixes
|
// required for prefixes
|
||||||
Interval: true,
|
Interval: true,
|
||||||
KeyType: nftables.TypeIPAddr,
|
KeyType: r.af.setKeyType,
|
||||||
}
|
}
|
||||||
|
|
||||||
elements := convertPrefixesToSet(prefixes)
|
elements := r.convertPrefixesToSet(prefixes)
|
||||||
nElements := len(elements)
|
nElements := len(elements)
|
||||||
|
|
||||||
maxElements := maxPrefixesSet * 2
|
maxElements := maxPrefixesSet * 2
|
||||||
@@ -553,23 +574,17 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
|
|||||||
return nfset, nil
|
return nfset, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
|
func (r *router) convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
|
||||||
var elements []nftables.SetElement
|
var elements []nftables.SetElement
|
||||||
for _, prefix := range prefixes {
|
for _, prefix := range prefixes {
|
||||||
// TODO: Implement IPv6 support
|
|
||||||
if prefix.Addr().Is6() {
|
|
||||||
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
|
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
|
||||||
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
|
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
|
||||||
firstIP := prefix.Addr()
|
firstIP := prefix.Addr()
|
||||||
lastIP := calculateLastIP(prefix).Next()
|
lastIP := calculateLastIP(prefix).Next()
|
||||||
|
|
||||||
elements = append(elements,
|
elements = append(elements,
|
||||||
// the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247
|
// the nft tool also adds a zero-address IntervalEnd element, see https://github.com/google/nftables/issues/247
|
||||||
// nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true},
|
// nftables.SetElement{Key: make([]byte, r.af.addrLen), IntervalEnd: true},
|
||||||
nftables.SetElement{Key: firstIP.AsSlice()},
|
nftables.SetElement{Key: firstIP.AsSlice()},
|
||||||
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
|
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
|
||||||
)
|
)
|
||||||
@@ -579,10 +594,20 @@ func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
|
|||||||
|
|
||||||
// calculateLastIP determines the last IP in a given prefix.
|
// calculateLastIP determines the last IP in a given prefix.
|
||||||
func calculateLastIP(prefix netip.Prefix) netip.Addr {
|
func calculateLastIP(prefix netip.Prefix) netip.Addr {
|
||||||
hostMask := ^uint32(0) >> prefix.Masked().Bits()
|
masked := prefix.Masked()
|
||||||
lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask
|
if masked.Addr().Is4() {
|
||||||
|
hostMask := ^uint32(0) >> masked.Bits()
|
||||||
|
lastIP := uint32FromNetipAddr(masked.Addr()) | hostMask
|
||||||
|
return netip.AddrFrom4(uint32ToBytes(lastIP))
|
||||||
|
}
|
||||||
|
|
||||||
return netip.AddrFrom4(uint32ToBytes(lastIP))
|
// IPv6: set host bits to all 1s
|
||||||
|
b := masked.Addr().As16()
|
||||||
|
bits := masked.Bits()
|
||||||
|
for i := bits; i < 128; i++ {
|
||||||
|
b[i/8] |= 1 << (7 - i%8)
|
||||||
|
}
|
||||||
|
return netip.AddrFrom16(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Utility function to convert netip.Addr to uint32.
|
// Utility function to convert netip.Addr to uint32.
|
||||||
@@ -834,9 +859,12 @@ func (r *router) addPostroutingRules() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||||
// TODO: Add IPv6 support
|
|
||||||
func (r *router) addMSSClampingRules() error {
|
func (r *router) addMSSClampingRules() error {
|
||||||
mss := r.mtu - ipTCPHeaderMinSize
|
overhead := uint16(ipv4TCPHeaderSize)
|
||||||
|
if r.af.tableFamily == nftables.TableFamilyIPv6 {
|
||||||
|
overhead = ipv6TCPHeaderSize
|
||||||
|
}
|
||||||
|
mss := r.mtu - overhead
|
||||||
|
|
||||||
exprsOut := []expr.Any{
|
exprsOut := []expr.Any{
|
||||||
&expr.Meta{
|
&expr.Meta{
|
||||||
@@ -1043,17 +1071,22 @@ func (r *router) acceptFilterTableRules() error {
|
|||||||
log.Debugf("Used %s to add accept forward and input rules", fw)
|
log.Debugf("Used %s to add accept forward and input rules", fw)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Try iptables first and fallback to nftables if iptables is not available
|
// Try iptables first and fallback to nftables if iptables is not available.
|
||||||
ipt, err := iptables.New()
|
// Use the correct protocol (iptables vs ip6tables) for the address family.
|
||||||
|
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// iptables is not available but the filter table exists
|
|
||||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||||
|
|
||||||
fw = "nftables"
|
fw = "nftables"
|
||||||
return r.acceptFilterRulesNftables(r.filterTable)
|
return r.acceptFilterRulesNftables(r.filterTable)
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.acceptFilterRulesIptables(ipt)
|
if err := r.acceptFilterRulesIptables(ipt); err != nil {
|
||||||
|
log.Warnf("iptables failed (table may be incompatible), falling back to nftables: %v", err)
|
||||||
|
fw = "nftables"
|
||||||
|
return r.acceptFilterRulesNftables(r.filterTable)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||||
@@ -1222,13 +1255,17 @@ func (r *router) removeFilterTableRules() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ipt, err := iptables.New()
|
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
|
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
|
||||||
return r.removeAcceptRulesFromTable(r.filterTable)
|
return r.removeAcceptRulesFromTable(r.filterTable)
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.removeAcceptFilterRulesIptables(ipt)
|
if err := r.removeAcceptFilterRulesIptables(ipt); err != nil {
|
||||||
|
log.Debugf("iptables removal failed (table may be incompatible), falling back to nftables: %v", err)
|
||||||
|
return r.removeAcceptRulesFromTable(r.filterTable)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
|
func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
|
||||||
@@ -1295,7 +1332,7 @@ func (r *router) removeExternalChainsRules() error {
|
|||||||
func (r *router) findExternalChains() []*nftables.Chain {
|
func (r *router) findExternalChains() []*nftables.Chain {
|
||||||
var chains []*nftables.Chain
|
var chains []*nftables.Chain
|
||||||
|
|
||||||
families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet}
|
families := []nftables.TableFamily{r.af.tableFamily, nftables.TableFamilyINet}
|
||||||
|
|
||||||
for _, family := range families {
|
for _, family := range families {
|
||||||
allChains, err := r.conn.ListChainsOfTableFamily(family)
|
allChains, err := r.conn.ListChainsOfTableFamily(family)
|
||||||
@@ -1319,8 +1356,8 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skip all iptables-managed tables in the ip family
|
// Skip iptables/ip6tables-managed tables (adding nft-native rules breaks iptables-save compat)
|
||||||
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
if (chain.Table.Family == nftables.TableFamilyIPv4 || chain.Table.Family == nftables.TableFamilyIPv6) && isIptablesTable(chain.Table.Name) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1461,7 +1498,7 @@ func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
|||||||
return rule, nil
|
return rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
protoNum, err := protoToInt(rule.Protocol)
|
protoNum, err := r.af.protoNum(rule.Protocol)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||||
}
|
}
|
||||||
@@ -1524,7 +1561,7 @@ func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, rule
|
|||||||
dnatExprs = append(dnatExprs,
|
dnatExprs = append(dnatExprs,
|
||||||
&expr.NAT{
|
&expr.NAT{
|
||||||
Type: expr.NATTypeDestNAT,
|
Type: expr.NATTypeDestNAT,
|
||||||
Family: uint32(nftables.TableFamilyIPv4),
|
Family: uint32(r.af.tableFamily),
|
||||||
RegAddrMin: 1,
|
RegAddrMin: 1,
|
||||||
RegProtoMin: regProtoMin,
|
RegProtoMin: regProtoMin,
|
||||||
RegProtoMax: regProtoMax,
|
RegProtoMax: regProtoMax,
|
||||||
@@ -1620,7 +1657,7 @@ func (r *router) addXTablesRedirect(dnatExprs []expr.Any, ruleKey string, rule f
|
|||||||
dnatRule := &nftables.Rule{
|
dnatRule := &nftables.Rule{
|
||||||
Table: &nftables.Table{
|
Table: &nftables.Table{
|
||||||
Name: tableNat,
|
Name: tableNat,
|
||||||
Family: nftables.TableFamilyIPv4,
|
Family: r.af.tableFamily,
|
||||||
},
|
},
|
||||||
Chain: &nftables.Chain{
|
Chain: &nftables.Chain{
|
||||||
Name: chainNameNatPrerouting,
|
Name: chainNameNatPrerouting,
|
||||||
@@ -1655,8 +1692,8 @@ func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey
|
|||||||
&expr.Payload{
|
&expr.Payload{
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
Offset: 16,
|
Offset: r.af.dstAddrOffset,
|
||||||
Len: 4,
|
Len: r.af.addrLen,
|
||||||
},
|
},
|
||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
Op: expr.CmpOpEq,
|
Op: expr.CmpOpEq,
|
||||||
@@ -1734,7 +1771,7 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
|||||||
return fmt.Errorf("get set %s: %w", set.HashedName(), err)
|
return fmt.Errorf("get set %s: %w", set.HashedName(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
elements := convertPrefixesToSet(prefixes)
|
elements := r.convertPrefixesToSet(prefixes)
|
||||||
if err := r.conn.SetAddElements(nfset, elements); err != nil {
|
if err := r.conn.SetAddElements(nfset, elements); err != nil {
|
||||||
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
|
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
|
||||||
}
|
}
|
||||||
@@ -1756,7 +1793,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
protoNum, err := protoToInt(protocol)
|
protoNum, err := r.af.protoNum(protocol)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("convert protocol to number: %w", err)
|
return fmt.Errorf("convert protocol to number: %w", err)
|
||||||
}
|
}
|
||||||
@@ -1787,7 +1824,11 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
|
bits := 32
|
||||||
|
if localAddr.Is6() {
|
||||||
|
bits = 128
|
||||||
|
}
|
||||||
|
exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...)
|
||||||
|
|
||||||
exprs = append(exprs,
|
exprs = append(exprs,
|
||||||
&expr.Immediate{
|
&expr.Immediate{
|
||||||
@@ -1800,7 +1841,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
|||||||
},
|
},
|
||||||
&expr.NAT{
|
&expr.NAT{
|
||||||
Type: expr.NATTypeDestNAT,
|
Type: expr.NATTypeDestNAT,
|
||||||
Family: uint32(nftables.TableFamilyIPv4),
|
Family: uint32(r.af.tableFamily),
|
||||||
RegAddrMin: 1,
|
RegAddrMin: 1,
|
||||||
RegProtoMin: 2,
|
RegProtoMin: 2,
|
||||||
RegProtoMax: 0,
|
RegProtoMax: 0,
|
||||||
@@ -1887,7 +1928,7 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
protoNum, err := protoToInt(protocol)
|
protoNum, err := r.af.protoNum(protocol)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("convert protocol to number: %w", err)
|
return fmt.Errorf("convert protocol to number: %w", err)
|
||||||
}
|
}
|
||||||
@@ -1912,7 +1953,11 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
|
bits := 32
|
||||||
|
if localAddr.Is6() {
|
||||||
|
bits = 128
|
||||||
|
}
|
||||||
|
exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...)
|
||||||
|
|
||||||
exprs = append(exprs,
|
exprs = append(exprs,
|
||||||
&expr.Immediate{
|
&expr.Immediate{
|
||||||
@@ -1925,7 +1970,7 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
|
|||||||
},
|
},
|
||||||
&expr.NAT{
|
&expr.NAT{
|
||||||
Type: expr.NATTypeDestNAT,
|
Type: expr.NATTypeDestNAT,
|
||||||
Family: uint32(nftables.TableFamilyIPv4),
|
Family: uint32(r.af.tableFamily),
|
||||||
RegAddrMin: 1,
|
RegAddrMin: 1,
|
||||||
RegProtoMin: 2,
|
RegProtoMin: 2,
|
||||||
},
|
},
|
||||||
@@ -1993,45 +2038,44 @@ func (r *router) applyNetwork(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if network.IsPrefix() {
|
if network.IsPrefix() {
|
||||||
return applyPrefix(network.Prefix, isSource), nil
|
return r.applyPrefix(network.Prefix, isSource), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyPrefix generates nftables expressions for a CIDR prefix
|
// applyPrefix generates nftables expressions for a CIDR prefix
|
||||||
func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
|
func (r *router) applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
|
||||||
// dst offset
|
// dst offset by default
|
||||||
offset := uint32(16)
|
offset := r.af.dstAddrOffset
|
||||||
if isSource {
|
if isSource {
|
||||||
// src offset
|
// src offset
|
||||||
offset = 12
|
offset = r.af.srcAddrOffset
|
||||||
}
|
}
|
||||||
|
|
||||||
ones := prefix.Bits()
|
ones := prefix.Bits()
|
||||||
// 0.0.0.0/0 doesn't need extra expressions
|
// unspecified address (/0) doesn't need extra expressions
|
||||||
if ones == 0 {
|
if ones == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
mask := net.CIDRMask(ones, 32)
|
mask := net.CIDRMask(ones, r.af.totalBits)
|
||||||
|
xor := make([]byte, r.af.addrLen)
|
||||||
|
|
||||||
return []expr.Any{
|
return []expr.Any{
|
||||||
&expr.Payload{
|
&expr.Payload{
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
Offset: offset,
|
Offset: offset,
|
||||||
Len: 4,
|
Len: r.af.addrLen,
|
||||||
},
|
},
|
||||||
// netmask
|
|
||||||
&expr.Bitwise{
|
&expr.Bitwise{
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
SourceRegister: 1,
|
SourceRegister: 1,
|
||||||
Len: 4,
|
Len: r.af.addrLen,
|
||||||
Mask: mask,
|
Mask: mask,
|
||||||
Xor: []byte{0, 0, 0, 0},
|
Xor: xor,
|
||||||
},
|
},
|
||||||
// net address
|
|
||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
Op: expr.CmpOpEq,
|
Op: expr.CmpOpEq,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
@@ -2114,13 +2158,12 @@ func getCtNewExprs() []expr.Any {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
|
func (r *router) getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
|
||||||
|
// dst offset by default
|
||||||
// dst offset
|
offset := r.af.dstAddrOffset
|
||||||
offset := uint32(16)
|
|
||||||
if isSource {
|
if isSource {
|
||||||
// src offset
|
// src offset
|
||||||
offset = 12
|
offset = r.af.srcAddrOffset
|
||||||
}
|
}
|
||||||
|
|
||||||
return []expr.Any{
|
return []expr.Any{
|
||||||
@@ -2128,7 +2171,7 @@ func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any
|
|||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
Offset: offset,
|
Offset: offset,
|
||||||
Len: 4,
|
Len: r.af.addrLen,
|
||||||
},
|
},
|
||||||
&expr.Lookup{
|
&expr.Lookup{
|
||||||
SourceRegister: 1,
|
SourceRegister: 1,
|
||||||
|
|||||||
@@ -90,8 +90,9 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build CIDR matching expressions
|
// Build CIDR matching expressions
|
||||||
sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true)
|
testRouter := &router{af: afIPv4}
|
||||||
destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false)
|
sourceExp := testRouter.applyPrefix(testCase.InputPair.Source.Prefix, true)
|
||||||
|
destExp := testRouter.applyPrefix(testCase.InputPair.Destination.Prefix, false)
|
||||||
|
|
||||||
// Combine all expressions in the correct order
|
// Combine all expressions in the correct order
|
||||||
// nolint:gocritic
|
// nolint:gocritic
|
||||||
@@ -508,6 +509,136 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNftablesCreateIpSet_IPv6(t *testing.T) {
|
||||||
|
if check() != NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
workTable, err := createWorkTableIPv6()
|
||||||
|
require.NoError(t, err, "Failed to create v6 work table")
|
||||||
|
defer deleteWorkTableIPv6()
|
||||||
|
|
||||||
|
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err, "Failed to create router")
|
||||||
|
require.NoError(t, r.init(workTable))
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, r.Reset(), "Failed to reset router")
|
||||||
|
}()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sources []netip.Prefix
|
||||||
|
expected []netip.Prefix
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Single IPv6",
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("2001:db8::1/128")},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Multiple IPv6 Subnets",
|
||||||
|
sources: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("fd00::/64"),
|
||||||
|
netip.MustParsePrefix("2001:db8::/48"),
|
||||||
|
netip.MustParsePrefix("fe80::/10"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Overlapping IPv6",
|
||||||
|
sources: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("fd00::/48"),
|
||||||
|
netip.MustParsePrefix("fd00::/64"),
|
||||||
|
netip.MustParsePrefix("fd00::1/128"),
|
||||||
|
},
|
||||||
|
expected: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("fd00::/48"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Mixed prefix lengths",
|
||||||
|
sources: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("2001:db8:1::/48"),
|
||||||
|
netip.MustParsePrefix("2001:db8:2::1/128"),
|
||||||
|
netip.MustParsePrefix("fd00:abcd::/32"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
setName := firewall.NewPrefixSet(tt.sources).HashedName()
|
||||||
|
set, err := r.createIpSet(setName, setInput{prefixes: tt.sources})
|
||||||
|
require.NoError(t, err, "Failed to create IPv6 set")
|
||||||
|
require.NotNil(t, set)
|
||||||
|
|
||||||
|
assert.Equal(t, setName, set.Name)
|
||||||
|
assert.True(t, set.Interval)
|
||||||
|
assert.Equal(t, nftables.TypeIP6Addr, set.KeyType)
|
||||||
|
|
||||||
|
fetchedSet, err := r.conn.GetSetByName(r.workTable, setName)
|
||||||
|
require.NoError(t, err, "Failed to fetch created set")
|
||||||
|
|
||||||
|
elements, err := r.conn.GetSetElements(fetchedSet)
|
||||||
|
require.NoError(t, err, "Failed to get set elements")
|
||||||
|
|
||||||
|
uniquePrefixes := make(map[string]bool)
|
||||||
|
for _, elem := range elements {
|
||||||
|
if !elem.IntervalEnd && len(elem.Key) == 16 {
|
||||||
|
ip := netip.AddrFrom16([16]byte(elem.Key))
|
||||||
|
uniquePrefixes[ip.String()] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedCount := len(tt.expected)
|
||||||
|
if expectedCount == 0 {
|
||||||
|
expectedCount = len(tt.sources)
|
||||||
|
}
|
||||||
|
assert.Equal(t, expectedCount, len(uniquePrefixes), "unique prefix count mismatch")
|
||||||
|
|
||||||
|
r.conn.DelSet(set)
|
||||||
|
require.NoError(t, r.conn.Flush())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createWorkTableIPv6() (*nftables.Table, error) {
|
||||||
|
sConn, err := nftables.New(nftables.AsLasting())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, t := range tables {
|
||||||
|
if t.Name == tableNameNetbird {
|
||||||
|
sConn.DelTable(t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv6})
|
||||||
|
err = sConn.Flush()
|
||||||
|
return table, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func deleteWorkTableIPv6() {
|
||||||
|
sConn, err := nftables.New(nftables.AsLasting())
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, t := range tables {
|
||||||
|
if t.Name == tableNameNetbird {
|
||||||
|
sConn.DelTable(t)
|
||||||
|
_ = sConn.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
|
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
@@ -627,7 +758,7 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
|
|||||||
|
|
||||||
func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
|
func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
|
||||||
var metaFound, cmpFound bool
|
var metaFound, cmpFound bool
|
||||||
expectedProto, _ := protoToInt(proto)
|
expectedProto, _ := afIPv4.protoNum(proto)
|
||||||
for _, e := range exprs {
|
for _, e := range exprs {
|
||||||
switch ex := e.(type) {
|
switch ex := e.(type) {
|
||||||
case *expr.Meta:
|
case *expr.Meta:
|
||||||
@@ -854,3 +985,55 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
|||||||
}
|
}
|
||||||
assert.Equal(t, 1, found, "NAT rule should exist in kernel")
|
assert.Equal(t, 1, found, "NAT rule should exist in kernel")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCalculateLastIP(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
prefix string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"10.0.0.0/24", "10.0.0.255"},
|
||||||
|
{"10.0.0.0/32", "10.0.0.0"},
|
||||||
|
{"0.0.0.0/0", "255.255.255.255"},
|
||||||
|
{"192.168.1.0/28", "192.168.1.15"},
|
||||||
|
{"fd00::/64", "fd00::ffff:ffff:ffff:ffff"},
|
||||||
|
{"fd00::/128", "fd00::"},
|
||||||
|
{"2001:db8::/48", "2001:db8:0:ffff:ffff:ffff:ffff:ffff"},
|
||||||
|
{"::/0", "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.prefix, func(t *testing.T) {
|
||||||
|
prefix := netip.MustParsePrefix(tt.prefix)
|
||||||
|
got := calculateLastIP(prefix)
|
||||||
|
assert.Equal(t, tt.want, got.String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertPrefixesToSet_IPv6(t *testing.T) {
|
||||||
|
r := &router{af: afIPv6}
|
||||||
|
prefixes := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("fd00::/64"),
|
||||||
|
netip.MustParsePrefix("2001:db8::1/128"),
|
||||||
|
}
|
||||||
|
|
||||||
|
elements := r.convertPrefixesToSet(prefixes)
|
||||||
|
|
||||||
|
// Each prefix produces 2 elements (start + end)
|
||||||
|
require.Len(t, elements, 4)
|
||||||
|
|
||||||
|
// fd00::/64 start
|
||||||
|
assert.Equal(t, netip.MustParseAddr("fd00::").As16(), [16]byte(elements[0].Key))
|
||||||
|
assert.False(t, elements[0].IntervalEnd)
|
||||||
|
|
||||||
|
// fd00::/64 end (fd00:0:0:1::, one past the last)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("fd00:0:0:1::").As16(), [16]byte(elements[1].Key))
|
||||||
|
assert.True(t, elements[1].IntervalEnd)
|
||||||
|
|
||||||
|
// 2001:db8::1/128 start
|
||||||
|
assert.Equal(t, netip.MustParseAddr("2001:db8::1").As16(), [16]byte(elements[2].Key))
|
||||||
|
assert.False(t, elements[2].IntervalEnd)
|
||||||
|
|
||||||
|
// 2001:db8::1/128 end (2001:db8::2)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("2001:db8::2").As16(), [16]byte(elements[3].Key))
|
||||||
|
assert.True(t, elements[3].IntervalEnd)
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,8 +5,10 @@ import (
|
|||||||
"os/exec"
|
"os/exec"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -29,15 +31,20 @@ func (m *Manager) Close(*statemanager.Manager) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if !isFirewallRuleActive(firewallRuleName) {
|
var merr *multierror.Error
|
||||||
return nil
|
if isFirewallRuleActive(firewallRuleName) {
|
||||||
|
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove windows firewall rule: %w", err))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
|
if isFirewallRuleActive(firewallRuleName + "-v6") {
|
||||||
return fmt.Errorf("couldn't remove windows firewall: %w", err)
|
if err := manageFirewallRule(firewallRuleName+"-v6", deleteRule); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove windows v6 firewall rule: %w", err))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllowNetbird allows netbird interface traffic
|
// AllowNetbird allows netbird interface traffic
|
||||||
@@ -46,17 +53,33 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if isFirewallRuleActive(firewallRuleName) {
|
if !isFirewallRuleActive(firewallRuleName) {
|
||||||
return nil
|
if err := manageFirewallRule(firewallRuleName,
|
||||||
|
addRule,
|
||||||
|
"dir=in",
|
||||||
|
"enable=yes",
|
||||||
|
"action=allow",
|
||||||
|
"profile=any",
|
||||||
|
"localip="+m.wgIface.Address().IP.String(),
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return manageFirewallRule(firewallRuleName,
|
|
||||||
addRule,
|
if v6 := m.wgIface.Address().IPv6; v6.IsValid() && !isFirewallRuleActive(firewallRuleName+"-v6") {
|
||||||
"dir=in",
|
if err := manageFirewallRule(firewallRuleName+"-v6",
|
||||||
"enable=yes",
|
addRule,
|
||||||
"action=allow",
|
"dir=in",
|
||||||
"profile=any",
|
"enable=yes",
|
||||||
"localip="+m.wgIface.Address().IP.String(),
|
"action=allow",
|
||||||
)
|
"profile=any",
|
||||||
|
"localip="+v6.String(),
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func manageFirewallRule(ruleName string, action action, extraArgs ...string) error {
|
func manageFirewallRule(ruleName string, action action, extraArgs ...string) error {
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -64,5 +65,7 @@ type ConnKey struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c ConnKey) String() string {
|
func (c ConnKey) String() string {
|
||||||
return fmt.Sprintf("%s:%d → %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
return net.JoinHostPort(c.SrcIP.Unmap().String(), strconv.Itoa(int(c.SrcPort))) +
|
||||||
|
" → " +
|
||||||
|
net.JoinHostPort(c.DstIP.Unmap().String(), strconv.Itoa(int(c.DstPort)))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,9 +21,10 @@ const (
|
|||||||
// ICMPCleanupInterval is how often we check for stale ICMP connections
|
// ICMPCleanupInterval is how often we check for stale ICMP connections
|
||||||
ICMPCleanupInterval = 15 * time.Second
|
ICMPCleanupInterval = 15 * time.Second
|
||||||
|
|
||||||
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info,
|
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info.
|
||||||
// which includes the IP header (20 bytes) and transport header (8 bytes)
|
// IPv4: 20-byte header + 8-byte transport = 28 bytes.
|
||||||
MaxICMPPayloadLength = 28
|
// IPv6: 40-byte header + 8-byte transport = 48 bytes.
|
||||||
|
MaxICMPPayloadLength = 48
|
||||||
)
|
)
|
||||||
|
|
||||||
// ICMPConnKey uniquely identifies an ICMP connection
|
// ICMPConnKey uniquely identifies an ICMP connection
|
||||||
@@ -74,32 +75,64 @@ func (info ICMPInfo) String() string {
|
|||||||
return info.TypeCode.String()
|
return info.TypeCode.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// isErrorMessage returns true if this ICMP type carries original packet info
|
// isErrorMessage returns true if this ICMP type carries original packet info.
|
||||||
|
// Covers both ICMPv4 and ICMPv6 error types. Without a family field we match
|
||||||
|
// both sets; type 3 overlaps (v4 DestUnreachable / v6 TimeExceeded) so it's
|
||||||
|
// kept as a literal.
|
||||||
func (info ICMPInfo) isErrorMessage() bool {
|
func (info ICMPInfo) isErrorMessage() bool {
|
||||||
typ := info.TypeCode.Type()
|
typ := info.TypeCode.Type()
|
||||||
return typ == 3 || // Destination Unreachable
|
// ICMPv4 error types
|
||||||
typ == 5 || // Redirect
|
if typ == layers.ICMPv4TypeDestinationUnreachable ||
|
||||||
typ == 11 || // Time Exceeded
|
typ == layers.ICMPv4TypeRedirect ||
|
||||||
typ == 12 // Parameter Problem
|
typ == layers.ICMPv4TypeTimeExceeded ||
|
||||||
|
typ == layers.ICMPv4TypeParameterProblem {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// ICMPv6 error types (type 3 already matched above as v4 DestUnreachable)
|
||||||
|
if typ == layers.ICMPv6TypeDestinationUnreachable ||
|
||||||
|
typ == layers.ICMPv6TypePacketTooBig ||
|
||||||
|
typ == layers.ICMPv6TypeParameterProblem {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseOriginalPacket extracts info about the original packet from ICMP payload
|
// parseOriginalPacket extracts info about the original packet from ICMP payload
|
||||||
func (info ICMPInfo) parseOriginalPacket() string {
|
func (info ICMPInfo) parseOriginalPacket() string {
|
||||||
if info.PayloadLen < MaxICMPPayloadLength {
|
if info.PayloadLen == 0 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: handle IPv6
|
version := (info.PayloadData[0] >> 4) & 0xF
|
||||||
if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 {
|
|
||||||
|
var protocol uint8
|
||||||
|
var srcIP, dstIP net.IP
|
||||||
|
var transportData []byte
|
||||||
|
|
||||||
|
switch version {
|
||||||
|
case 4:
|
||||||
|
// 20-byte IPv4 header + 8-byte transport minimum
|
||||||
|
if info.PayloadLen < 28 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
protocol = info.PayloadData[9]
|
||||||
|
srcIP = net.IP(info.PayloadData[12:16])
|
||||||
|
dstIP = net.IP(info.PayloadData[16:20])
|
||||||
|
transportData = info.PayloadData[20:]
|
||||||
|
case 6:
|
||||||
|
// 40-byte IPv6 header + 8-byte transport minimum
|
||||||
|
if info.PayloadLen < 48 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
// Next Header field in IPv6 header
|
||||||
|
protocol = info.PayloadData[6]
|
||||||
|
srcIP = net.IP(info.PayloadData[8:24])
|
||||||
|
dstIP = net.IP(info.PayloadData[24:40])
|
||||||
|
transportData = info.PayloadData[40:]
|
||||||
|
default:
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
protocol := info.PayloadData[9]
|
|
||||||
srcIP := net.IP(info.PayloadData[12:16])
|
|
||||||
dstIP := net.IP(info.PayloadData[16:20])
|
|
||||||
|
|
||||||
transportData := info.PayloadData[20:]
|
|
||||||
|
|
||||||
switch nftypes.Protocol(protocol) {
|
switch nftypes.Protocol(protocol) {
|
||||||
case nftypes.TCP:
|
case nftypes.TCP:
|
||||||
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
||||||
@@ -247,9 +280,10 @@ func (t *ICMPTracker) track(
|
|||||||
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request.
|
||||||
|
// Accepts both ICMPv4 (type 0) and ICMPv6 (type 129) echo replies.
|
||||||
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
|
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
|
||||||
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
if icmpType != uint8(layers.ICMPv4TypeEchoReply) && icmpType != uint8(layers.ICMPv6TypeEchoReply) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -301,6 +335,13 @@ func (t *ICMPTracker) cleanup() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func icmpProtocolForAddr(ip netip.Addr) nftypes.Protocol {
|
||||||
|
if ip.Is6() {
|
||||||
|
return nftypes.ICMPv6
|
||||||
|
}
|
||||||
|
return nftypes.ICMP
|
||||||
|
}
|
||||||
|
|
||||||
// Close stops the cleanup routine and releases resources
|
// Close stops the cleanup routine and releases resources
|
||||||
func (t *ICMPTracker) Close() {
|
func (t *ICMPTracker) Close() {
|
||||||
t.tickerCancel()
|
t.tickerCancel()
|
||||||
@@ -316,7 +357,7 @@ func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []
|
|||||||
Type: typ,
|
Type: typ,
|
||||||
RuleID: ruleID,
|
RuleID: ruleID,
|
||||||
Direction: conn.Direction,
|
Direction: conn.Direction,
|
||||||
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
|
Protocol: icmpProtocolForAddr(conn.SourceIP),
|
||||||
SourceIP: conn.SourceIP,
|
SourceIP: conn.SourceIP,
|
||||||
DestIP: conn.DestIP,
|
DestIP: conn.DestIP,
|
||||||
ICMPType: conn.ICMPType,
|
ICMPType: conn.ICMPType,
|
||||||
@@ -334,7 +375,7 @@ func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Ad
|
|||||||
Type: nftypes.TypeStart,
|
Type: nftypes.TypeStart,
|
||||||
RuleID: ruleID,
|
RuleID: ruleID,
|
||||||
Direction: direction,
|
Direction: direction,
|
||||||
Protocol: nftypes.ICMP,
|
Protocol: icmpProtocolForAddr(srcIP),
|
||||||
SourceIP: srcIP,
|
SourceIP: srcIP,
|
||||||
DestIP: dstIP,
|
DestIP: dstIP,
|
||||||
ICMPType: typ,
|
ICMPType: typ,
|
||||||
|
|||||||
@@ -35,8 +35,10 @@ import (
|
|||||||
const (
|
const (
|
||||||
layerTypeAll = 255
|
layerTypeAll = 255
|
||||||
|
|
||||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
// ipv4TCPHeaderMinSize represents minimum IPv4 (20) + TCP (20) header size for MSS calculation
|
||||||
ipTCPHeaderMinSize = 40
|
ipv4TCPHeaderMinSize = 40
|
||||||
|
// ipv6TCPHeaderMinSize represents minimum IPv6 (40) + TCP (20) header size for MSS calculation
|
||||||
|
ipv6TCPHeaderMinSize = 60
|
||||||
)
|
)
|
||||||
|
|
||||||
// serviceKey represents a protocol/port combination for netstack service registry
|
// serviceKey represents a protocol/port combination for netstack service registry
|
||||||
@@ -137,9 +139,10 @@ type Manager struct {
|
|||||||
netstackServices map[serviceKey]struct{}
|
netstackServices map[serviceKey]struct{}
|
||||||
netstackServiceMutex sync.RWMutex
|
netstackServiceMutex sync.RWMutex
|
||||||
|
|
||||||
mtu uint16
|
mtu uint16
|
||||||
mssClampValue uint16
|
mssClampValueIPv4 uint16
|
||||||
mssClampEnabled bool
|
mssClampValueIPv6 uint16
|
||||||
|
mssClampEnabled bool
|
||||||
|
|
||||||
// Only one hook per protocol is supported. Outbound direction only.
|
// Only one hook per protocol is supported. Outbound direction only.
|
||||||
udpHookOut atomic.Pointer[packetHook]
|
udpHookOut atomic.Pointer[packetHook]
|
||||||
@@ -163,11 +166,28 @@ type decoder struct {
|
|||||||
icmp4 layers.ICMPv4
|
icmp4 layers.ICMPv4
|
||||||
icmp6 layers.ICMPv6
|
icmp6 layers.ICMPv6
|
||||||
decoded []gopacket.LayerType
|
decoded []gopacket.LayerType
|
||||||
parser *gopacket.DecodingLayerParser
|
parser4 *gopacket.DecodingLayerParser
|
||||||
|
parser6 *gopacket.DecodingLayerParser
|
||||||
|
|
||||||
dnatOrigPort uint16
|
dnatOrigPort uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// decodePacket decodes packet data using the appropriate parser based on IP version.
|
||||||
|
func (d *decoder) decodePacket(data []byte) error {
|
||||||
|
if len(data) == 0 {
|
||||||
|
return errors.New("empty packet")
|
||||||
|
}
|
||||||
|
version := data[0] >> 4
|
||||||
|
switch version {
|
||||||
|
case 4:
|
||||||
|
return d.parser4.DecodeLayers(data, &d.decoded)
|
||||||
|
case 6:
|
||||||
|
return d.parser6.DecodeLayers(data, &d.decoded)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown IP version %d", version)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Create userspace firewall manager constructor
|
// Create userspace firewall manager constructor
|
||||||
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||||
return create(iface, nil, disableServerRoutes, flowLogger, mtu)
|
return create(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||||
@@ -225,11 +245,17 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
d := &decoder{
|
d := &decoder{
|
||||||
decoded: []gopacket.LayerType{},
|
decoded: []gopacket.LayerType{},
|
||||||
}
|
}
|
||||||
d.parser = gopacket.NewDecodingLayerParser(
|
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||||
layers.LayerTypeIPv4,
|
layers.LayerTypeIPv4,
|
||||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
)
|
)
|
||||||
d.parser.IgnoreUnsupported = true
|
d.parser4.IgnoreUnsupported = true
|
||||||
|
|
||||||
|
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||||
|
layers.LayerTypeIPv6,
|
||||||
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
|
)
|
||||||
|
d.parser6.IgnoreUnsupported = true
|
||||||
return d
|
return d
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -255,7 +281,8 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
|
|
||||||
if !disableMSSClamping {
|
if !disableMSSClamping {
|
||||||
m.mssClampEnabled = true
|
m.mssClampEnabled = true
|
||||||
m.mssClampValue = mtu - ipTCPHeaderMinSize
|
m.mssClampValueIPv4 = mtu - ipv4TCPHeaderMinSize
|
||||||
|
m.mssClampValueIPv6 = mtu - ipv6TCPHeaderMinSize
|
||||||
}
|
}
|
||||||
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||||
return nil, fmt.Errorf("update local IPs: %w", err)
|
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||||
@@ -282,9 +309,14 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, e
|
|||||||
wgPrefix := iface.Address().Network
|
wgPrefix := iface.Address().Network
|
||||||
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||||
|
if v6 := iface.Address().IPv6Net; v6.IsValid() {
|
||||||
|
sources = append(sources, netip.PrefixFrom(netip.IPv6Unspecified(), 0))
|
||||||
|
}
|
||||||
|
|
||||||
rule, err := m.addRouteFiltering(
|
rule, err := m.addRouteFiltering(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
sources,
|
||||||
firewall.Network{Prefix: wgPrefix},
|
firewall.Network{Prefix: wgPrefix},
|
||||||
firewall.ProtocolALL,
|
firewall.ProtocolALL,
|
||||||
nil,
|
nil,
|
||||||
@@ -292,7 +324,22 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, e
|
|||||||
firewall.ActionDrop,
|
firewall.ActionDrop,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("block wg nte : %w", err)
|
return nil, fmt.Errorf("block wg v4 net: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if v6Net := iface.Address().IPv6Net; v6Net.IsValid() {
|
||||||
|
log.Debugf("blocking invalid routed traffic for %s", v6Net)
|
||||||
|
if _, err := m.addRouteFiltering(
|
||||||
|
nil,
|
||||||
|
sources,
|
||||||
|
firewall.Network{Prefix: v6Net},
|
||||||
|
firewall.ProtocolALL,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
firewall.ActionDrop,
|
||||||
|
); err != nil {
|
||||||
|
return nil, fmt.Errorf("block wg v6 net: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Block networks that we're a client of
|
// TODO: Block networks that we're a client of
|
||||||
@@ -509,7 +556,7 @@ func (m *Manager) addRouteFiltering(
|
|||||||
mgmtId: id,
|
mgmtId: id,
|
||||||
sources: sources,
|
sources: sources,
|
||||||
dstSet: destination.Set,
|
dstSet: destination.Set,
|
||||||
protoLayer: protoToLayer(proto, layers.LayerTypeIPv4),
|
protoLayer: protoToLayer(proto, ipLayerFromPrefix(destination.Prefix)),
|
||||||
srcPort: sPort,
|
srcPort: sPort,
|
||||||
dstPort: dPort,
|
dstPort: dPort,
|
||||||
action: action,
|
action: action,
|
||||||
@@ -663,11 +710,7 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
destinations := matches[0].destinations
|
destinations := matches[0].destinations
|
||||||
for _, prefix := range prefixes {
|
destinations = append(destinations, prefixes...)
|
||||||
if prefix.Addr().Is4() {
|
|
||||||
destinations = append(destinations, prefix)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
slices.SortFunc(destinations, func(a, b netip.Prefix) int {
|
slices.SortFunc(destinations, func(a, b netip.Prefix) int {
|
||||||
cmp := a.Addr().Compare(b.Addr())
|
cmp := a.Addr().Compare(b.Addr())
|
||||||
@@ -706,7 +749,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
|||||||
d := m.decoders.Get().(*decoder)
|
d := m.decoders.Get().(*decoder)
|
||||||
defer m.decoders.Put(d)
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
if err := d.decodePacket(packetData); err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -790,12 +833,28 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var mssClampValue uint16
|
||||||
|
var ipHeaderSize int
|
||||||
|
switch d.decoded[0] {
|
||||||
|
case layers.LayerTypeIPv4:
|
||||||
|
mssClampValue = m.mssClampValueIPv4
|
||||||
|
ipHeaderSize = int(d.ip4.IHL) * 4
|
||||||
|
if ipHeaderSize < 20 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case layers.LayerTypeIPv6:
|
||||||
|
mssClampValue = m.mssClampValueIPv6
|
||||||
|
ipHeaderSize = 40
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
mssOptionIndex := -1
|
mssOptionIndex := -1
|
||||||
var currentMSS uint16
|
var currentMSS uint16
|
||||||
for i, opt := range d.tcp.Options {
|
for i, opt := range d.tcp.Options {
|
||||||
if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 {
|
if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 {
|
||||||
currentMSS = binary.BigEndian.Uint16(opt.OptionData)
|
currentMSS = binary.BigEndian.Uint16(opt.OptionData)
|
||||||
if currentMSS > m.mssClampValue {
|
if currentMSS > mssClampValue {
|
||||||
mssOptionIndex = i
|
mssOptionIndex = i
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -806,20 +865,15 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
ipHeaderSize := int(d.ip4.IHL) * 4
|
if !m.updateMSSOption(packetData, d, mssOptionIndex, mssClampValue, ipHeaderSize) {
|
||||||
if ipHeaderSize < 20 {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) {
|
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, mssClampValue)
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue)
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool {
|
func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex int, mssClampValue uint16, ipHeaderSize int) bool {
|
||||||
tcpHeaderStart := ipHeaderSize
|
tcpHeaderStart := ipHeaderSize
|
||||||
tcpOptionsStart := tcpHeaderStart + 20
|
tcpOptionsStart := tcpHeaderStart + 20
|
||||||
|
|
||||||
@@ -834,7 +888,7 @@ func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex,
|
|||||||
}
|
}
|
||||||
|
|
||||||
mssValueOffset := optOffset + 2
|
mssValueOffset := optOffset + 2
|
||||||
binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue)
|
binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], mssClampValue)
|
||||||
|
|
||||||
m.recalculateTCPChecksum(packetData, d, tcpHeaderStart)
|
m.recalculateTCPChecksum(packetData, d, tcpHeaderStart)
|
||||||
return true
|
return true
|
||||||
@@ -844,18 +898,32 @@ func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeade
|
|||||||
tcpLayer := packetData[tcpHeaderStart:]
|
tcpLayer := packetData[tcpHeaderStart:]
|
||||||
tcpLength := len(packetData) - tcpHeaderStart
|
tcpLength := len(packetData) - tcpHeaderStart
|
||||||
|
|
||||||
|
// Zero out existing checksum
|
||||||
tcpLayer[16] = 0
|
tcpLayer[16] = 0
|
||||||
tcpLayer[17] = 0
|
tcpLayer[17] = 0
|
||||||
|
|
||||||
|
// Build pseudo-header checksum based on IP version
|
||||||
var pseudoSum uint32
|
var pseudoSum uint32
|
||||||
pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1])
|
switch d.decoded[0] {
|
||||||
pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3])
|
case layers.LayerTypeIPv4:
|
||||||
pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1])
|
pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1])
|
||||||
pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3])
|
pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3])
|
||||||
pseudoSum += uint32(d.ip4.Protocol)
|
pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1])
|
||||||
pseudoSum += uint32(tcpLength)
|
pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3])
|
||||||
|
pseudoSum += uint32(d.ip4.Protocol)
|
||||||
|
pseudoSum += uint32(tcpLength)
|
||||||
|
case layers.LayerTypeIPv6:
|
||||||
|
for i := 0; i < 16; i += 2 {
|
||||||
|
pseudoSum += uint32(d.ip6.SrcIP[i])<<8 | uint32(d.ip6.SrcIP[i+1])
|
||||||
|
}
|
||||||
|
for i := 0; i < 16; i += 2 {
|
||||||
|
pseudoSum += uint32(d.ip6.DstIP[i])<<8 | uint32(d.ip6.DstIP[i+1])
|
||||||
|
}
|
||||||
|
pseudoSum += uint32(tcpLength)
|
||||||
|
pseudoSum += uint32(layers.IPProtocolTCP)
|
||||||
|
}
|
||||||
|
|
||||||
var sum = pseudoSum
|
sum := pseudoSum
|
||||||
for i := 0; i < tcpLength-1; i += 2 {
|
for i := 0; i < tcpLength-1; i += 2 {
|
||||||
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
|
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
|
||||||
}
|
}
|
||||||
@@ -893,6 +961,9 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData
|
|||||||
}
|
}
|
||||||
case layers.LayerTypeICMPv4:
|
case layers.LayerTypeICMPv4:
|
||||||
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
|
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
|
||||||
|
case layers.LayerTypeICMPv6:
|
||||||
|
id, tc := icmpv6EchoFields(d)
|
||||||
|
m.icmpTracker.TrackOutbound(srcIP, dstIP, id, tc, d.icmp6.Payload, size)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -906,6 +977,9 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
|
|||||||
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort)
|
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort)
|
||||||
case layers.LayerTypeICMPv4:
|
case layers.LayerTypeICMPv4:
|
||||||
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
|
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
|
||||||
|
case layers.LayerTypeICMPv6:
|
||||||
|
id, tc := icmpv6EchoFields(d)
|
||||||
|
m.icmpTracker.TrackInbound(srcIP, dstIP, id, tc, ruleID, d.icmp6.Payload, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
d.dnatOrigPort = 0
|
d.dnatOrigPort = 0
|
||||||
@@ -948,15 +1022,19 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
|||||||
|
|
||||||
// TODO: pass fragments of routed packets to forwarder
|
// TODO: pass fragments of routed packets to forwarder
|
||||||
if fragment {
|
if fragment {
|
||||||
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
|
if d.decoded[0] == layers.LayerTypeIPv4 {
|
||||||
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
|
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
|
||||||
|
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
|
||||||
|
} else {
|
||||||
|
m.logger.Trace2("packet is an IPv6 fragment: src=%v dst=%v", srcIP, dstIP)
|
||||||
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: optimize port DNAT by caching matched rules in conntrack
|
// TODO: optimize port DNAT by caching matched rules in conntrack
|
||||||
if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated {
|
if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated {
|
||||||
// Re-decode after port DNAT translation to update port information
|
// Re-decode after port DNAT translation to update port information
|
||||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
if err := d.decodePacket(packetData); err != nil {
|
||||||
m.logger.Error1("failed to re-decode packet after port DNAT: %v", err)
|
m.logger.Error1("failed to re-decode packet after port DNAT: %v", err)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -965,7 +1043,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
|||||||
|
|
||||||
if translated := m.translateInboundReverse(packetData, d); translated {
|
if translated := m.translateInboundReverse(packetData, d); translated {
|
||||||
// Re-decode after translation to get original addresses
|
// Re-decode after translation to get original addresses
|
||||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
if err := d.decodePacket(packetData); err != nil {
|
||||||
m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err)
|
m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -1097,6 +1175,48 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// icmpv6EchoFields extracts the echo identifier from an ICMPv6 packet and maps
|
||||||
|
// the ICMPv6 type code to an ICMPv4TypeCode so the ICMP conntrack can handle
|
||||||
|
// both families uniformly. The echo ID is in the first two payload bytes.
|
||||||
|
func icmpv6EchoFields(d *decoder) (id uint16, tc layers.ICMPv4TypeCode) {
|
||||||
|
if len(d.icmp6.Payload) >= 2 {
|
||||||
|
id = uint16(d.icmp6.Payload[0])<<8 | uint16(d.icmp6.Payload[1])
|
||||||
|
}
|
||||||
|
// Map ICMPv6 echo types to ICMPv4 equivalents for unified tracking.
|
||||||
|
switch d.icmp6.TypeCode.Type() {
|
||||||
|
case layers.ICMPv6TypeEchoRequest:
|
||||||
|
tc = layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0)
|
||||||
|
case layers.ICMPv6TypeEchoReply:
|
||||||
|
tc = layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoReply, 0)
|
||||||
|
default:
|
||||||
|
tc = layers.CreateICMPv4TypeCode(d.icmp6.TypeCode.Type(), d.icmp6.TypeCode.Code())
|
||||||
|
}
|
||||||
|
return id, tc
|
||||||
|
}
|
||||||
|
|
||||||
|
// protoLayerMatches checks if a packet's protocol layer matches a rule's expected
|
||||||
|
// protocol layer. ICMPv4 and ICMPv6 are treated as equivalent when matching
|
||||||
|
// ICMP rules since management sends a single ICMP rule for both families.
|
||||||
|
func protoLayerMatches(ruleLayer, packetLayer gopacket.LayerType) bool {
|
||||||
|
if ruleLayer == packetLayer {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if ruleLayer == layers.LayerTypeICMPv4 && packetLayer == layers.LayerTypeICMPv6 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if ruleLayer == layers.LayerTypeICMPv6 && packetLayer == layers.LayerTypeICMPv4 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func ipLayerFromPrefix(p netip.Prefix) gopacket.LayerType {
|
||||||
|
if p.Addr().Is6() {
|
||||||
|
return layers.LayerTypeIPv6
|
||||||
|
}
|
||||||
|
return layers.LayerTypeIPv4
|
||||||
|
}
|
||||||
|
|
||||||
func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType {
|
func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType {
|
||||||
switch proto {
|
switch proto {
|
||||||
case firewall.ProtocolTCP:
|
case firewall.ProtocolTCP:
|
||||||
@@ -1120,8 +1240,10 @@ func getProtocolFromPacket(d *decoder) nftypes.Protocol {
|
|||||||
return nftypes.TCP
|
return nftypes.TCP
|
||||||
case layers.LayerTypeUDP:
|
case layers.LayerTypeUDP:
|
||||||
return nftypes.UDP
|
return nftypes.UDP
|
||||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
case layers.LayerTypeICMPv4:
|
||||||
return nftypes.ICMP
|
return nftypes.ICMP
|
||||||
|
case layers.LayerTypeICMPv6:
|
||||||
|
return nftypes.ICMPv6
|
||||||
default:
|
default:
|
||||||
return nftypes.ProtocolUnknown
|
return nftypes.ProtocolUnknown
|
||||||
}
|
}
|
||||||
@@ -1142,7 +1264,7 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
|
|||||||
// It returns true, false if the packet is valid and not a fragment.
|
// It returns true, false if the packet is valid and not a fragment.
|
||||||
// It returns true, true if the packet is a fragment and valid.
|
// It returns true, true if the packet is a fragment and valid.
|
||||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
|
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
|
||||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
if err := d.decodePacket(packetData); err != nil {
|
||||||
m.logger.Trace1("couldn't decode packet, err: %s", err)
|
m.logger.Trace1("couldn't decode packet, err: %s", err)
|
||||||
return false, false
|
return false, false
|
||||||
}
|
}
|
||||||
@@ -1155,10 +1277,18 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fragments are also valid
|
// Fragments are also valid
|
||||||
if l == 1 && d.decoded[0] == layers.LayerTypeIPv4 {
|
if l == 1 {
|
||||||
ip4 := d.ip4
|
switch d.decoded[0] {
|
||||||
if ip4.Flags&layers.IPv4MoreFragments != 0 || ip4.FragOffset != 0 {
|
case layers.LayerTypeIPv4:
|
||||||
return true, true
|
if d.ip4.Flags&layers.IPv4MoreFragments != 0 || d.ip4.FragOffset != 0 {
|
||||||
|
return true, true
|
||||||
|
}
|
||||||
|
case layers.LayerTypeIPv6:
|
||||||
|
// IPv6 uses Fragment extension header (NextHeader=44). If gopacket
|
||||||
|
// only decoded the IPv6 layer, the transport is in a fragment.
|
||||||
|
if d.ip6.NextHeader == layers.IPProtocolIPv6Fragment {
|
||||||
|
return true, true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1196,21 +1326,34 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr,
|
|||||||
size,
|
size,
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: ICMPv6
|
case layers.LayerTypeICMPv6:
|
||||||
|
id, _ := icmpv6EchoFields(d)
|
||||||
|
return m.icmpTracker.IsValidInbound(
|
||||||
|
srcIP,
|
||||||
|
dstIP,
|
||||||
|
id,
|
||||||
|
d.icmp6.TypeCode.Type(),
|
||||||
|
size,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// isSpecialICMP returns true if the packet is a special ICMP packet that should be allowed
|
// isSpecialICMP returns true if the packet is a special ICMP error packet that should be allowed.
|
||||||
func (m *Manager) isSpecialICMP(d *decoder) bool {
|
func (m *Manager) isSpecialICMP(d *decoder) bool {
|
||||||
if d.decoded[1] != layers.LayerTypeICMPv4 {
|
switch d.decoded[1] {
|
||||||
return false
|
case layers.LayerTypeICMPv4:
|
||||||
|
icmpType := d.icmp4.TypeCode.Type()
|
||||||
|
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
|
||||||
|
icmpType == layers.ICMPv4TypeTimeExceeded
|
||||||
|
case layers.LayerTypeICMPv6:
|
||||||
|
icmpType := d.icmp6.TypeCode.Type()
|
||||||
|
return icmpType == layers.ICMPv6TypeDestinationUnreachable ||
|
||||||
|
icmpType == layers.ICMPv6TypePacketTooBig ||
|
||||||
|
icmpType == layers.ICMPv6TypeTimeExceeded
|
||||||
}
|
}
|
||||||
|
return false
|
||||||
icmpType := d.icmp4.TypeCode.Type()
|
|
||||||
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
|
|
||||||
icmpType == layers.ICMPv4TypeTimeExceeded
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte) ([]byte, bool) {
|
func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte) ([]byte, bool) {
|
||||||
@@ -1267,7 +1410,7 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
|
|||||||
return rule.mgmtId, rule.drop, true
|
return rule.mgmtId, rule.drop, true
|
||||||
}
|
}
|
||||||
|
|
||||||
if payloadLayer != rule.protoLayer {
|
if !protoLayerMatches(rule.protoLayer, payloadLayer) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1302,8 +1445,7 @@ func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.Lay
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool {
|
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool {
|
||||||
// TODO: handle ipv6 vs ipv4 icmp rules
|
if rule.protoLayer != layerTypeAll && !protoLayerMatches(rule.protoLayer, protoLayer) {
|
||||||
if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1473,7 +1615,8 @@ func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// traffic to our other local interfaces (not NetBird IP) - always forward
|
// traffic to our other local interfaces (not NetBird IP) - always forward
|
||||||
if dstIP != m.wgIface.Address().IP {
|
addr := m.wgIface.Address()
|
||||||
|
if dstIP != addr.IP && (!addr.IPv6.IsValid() || dstIP != addr.IPv6) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1023,7 +1023,8 @@ func BenchmarkMSSClamping(b *testing.B) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
manager.mssClampEnabled = true
|
manager.mssClampEnabled = true
|
||||||
manager.mssClampValue = 1240
|
manager.mssClampValueIPv4 = 1240
|
||||||
|
manager.mssClampValueIPv6 = 1220
|
||||||
|
|
||||||
srcIP := net.ParseIP("100.64.0.2")
|
srcIP := net.ParseIP("100.64.0.2")
|
||||||
dstIP := net.ParseIP("8.8.8.8")
|
dstIP := net.ParseIP("8.8.8.8")
|
||||||
@@ -1088,7 +1089,8 @@ func BenchmarkMSSClampingOverhead(b *testing.B) {
|
|||||||
|
|
||||||
manager.mssClampEnabled = sc.enabled
|
manager.mssClampEnabled = sc.enabled
|
||||||
if sc.enabled {
|
if sc.enabled {
|
||||||
manager.mssClampValue = 1240
|
manager.mssClampValueIPv4 = 1240
|
||||||
|
manager.mssClampValueIPv6 = 1220
|
||||||
}
|
}
|
||||||
|
|
||||||
srcIP := net.ParseIP("100.64.0.2")
|
srcIP := net.ParseIP("100.64.0.2")
|
||||||
@@ -1141,7 +1143,8 @@ func BenchmarkMSSClampingMemory(b *testing.B) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
manager.mssClampEnabled = true
|
manager.mssClampEnabled = true
|
||||||
manager.mssClampValue = 1240
|
manager.mssClampValueIPv4 = 1240
|
||||||
|
manager.mssClampValueIPv6 = 1220
|
||||||
|
|
||||||
srcIP := net.ParseIP("100.64.0.2")
|
srcIP := net.ParseIP("100.64.0.2")
|
||||||
dstIP := net.ParseIP("8.8.8.8")
|
dstIP := net.ParseIP("8.8.8.8")
|
||||||
|
|||||||
@@ -539,53 +539,236 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPeerACLFilteringIPv6(t *testing.T) {
|
||||||
|
localIP := netip.MustParseAddr("100.10.0.100")
|
||||||
|
localIPv6 := netip.MustParseAddr("fd00::100")
|
||||||
|
wgNet := netip.MustParsePrefix("100.10.0.0/16")
|
||||||
|
wgNetV6 := netip.MustParsePrefix("fd00::/64")
|
||||||
|
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: localIP,
|
||||||
|
Network: wgNet,
|
||||||
|
IPv6: localIPv6,
|
||||||
|
IPv6Net: wgNetV6,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||||
|
|
||||||
|
err = manager.UpdateLocalIPs()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
srcIP string
|
||||||
|
dstIP string
|
||||||
|
proto fw.Protocol
|
||||||
|
srcPort uint16
|
||||||
|
dstPort uint16
|
||||||
|
ruleIP string
|
||||||
|
ruleProto fw.Protocol
|
||||||
|
ruleDstPort *fw.Port
|
||||||
|
ruleAction fw.Action
|
||||||
|
shouldBeBlocked bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "IPv6: allow TCP from peer",
|
||||||
|
srcIP: "fd00::1",
|
||||||
|
dstIP: "fd00::100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 443,
|
||||||
|
ruleIP: "fd00::1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6: allow UDP from peer",
|
||||||
|
srcIP: "fd00::1",
|
||||||
|
dstIP: "fd00::100",
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 53,
|
||||||
|
ruleIP: "fd00::1",
|
||||||
|
ruleProto: fw.ProtocolUDP,
|
||||||
|
ruleDstPort: &fw.Port{Values: []uint16{53}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6: allow ICMPv6 from peer",
|
||||||
|
srcIP: "fd00::1",
|
||||||
|
dstIP: "fd00::100",
|
||||||
|
proto: fw.ProtocolICMP,
|
||||||
|
ruleIP: "fd00::1",
|
||||||
|
ruleProto: fw.ProtocolICMP,
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6: block TCP without rule",
|
||||||
|
srcIP: "fd00::2",
|
||||||
|
dstIP: "fd00::100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 443,
|
||||||
|
ruleIP: "fd00::1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6: drop rule",
|
||||||
|
srcIP: "fd00::1",
|
||||||
|
dstIP: "fd00::100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 22,
|
||||||
|
ruleIP: "fd00::1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{Values: []uint16{22}},
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6: allow all protocols",
|
||||||
|
srcIP: "fd00::1",
|
||||||
|
dstIP: "fd00::100",
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 9999,
|
||||||
|
ruleIP: "fd00::1",
|
||||||
|
ruleProto: fw.ProtocolALL,
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6: v4 wildcard ICMP rule matches ICMPv6 via protoLayerMatches",
|
||||||
|
srcIP: "fd00::1",
|
||||||
|
dstIP: "fd00::100",
|
||||||
|
proto: fw.ProtocolICMP,
|
||||||
|
ruleIP: "0.0.0.0",
|
||||||
|
ruleProto: fw.ProtocolICMP,
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("IPv6 implicit DROP (no rules)", func(t *testing.T) {
|
||||||
|
packet := createTestPacket(t, "fd00::1", "fd00::100", fw.ProtocolTCP, 12345, 443)
|
||||||
|
isDropped := manager.FilterInbound(packet, 0)
|
||||||
|
require.True(t, isDropped, "IPv6 packet should be dropped when no rules exist")
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
if tc.ruleAction == fw.ActionDrop {
|
||||||
|
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
for _, rule := range rules {
|
||||||
|
require.NoError(t, manager.DeletePeerRule(rule))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, rules)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
for _, rule := range rules {
|
||||||
|
require.NoError(t, manager.DeletePeerRule(rule))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||||
|
isDropped := manager.FilterInbound(packet, 0)
|
||||||
|
require.Equal(t, tc.shouldBeBlocked, isDropped, "packet filter result mismatch")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcPort, dstPort uint16) []byte {
|
func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcPort, dstPort uint16) []byte {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
src := net.ParseIP(srcIP)
|
||||||
|
dst := net.ParseIP(dstIP)
|
||||||
|
|
||||||
buf := gopacket.NewSerializeBuffer()
|
buf := gopacket.NewSerializeBuffer()
|
||||||
opts := gopacket.SerializeOptions{
|
opts := gopacket.SerializeOptions{
|
||||||
ComputeChecksums: true,
|
ComputeChecksums: true,
|
||||||
FixLengths: true,
|
FixLengths: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
ipLayer := &layers.IPv4{
|
// Detect address family
|
||||||
Version: 4,
|
isV6 := src.To4() == nil
|
||||||
TTL: 64,
|
|
||||||
SrcIP: net.ParseIP(srcIP),
|
|
||||||
DstIP: net.ParseIP(dstIP),
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
switch proto {
|
|
||||||
case fw.ProtocolTCP:
|
|
||||||
ipLayer.Protocol = layers.IPProtocolTCP
|
|
||||||
tcp := &layers.TCP{
|
|
||||||
SrcPort: layers.TCPPort(srcPort),
|
|
||||||
DstPort: layers.TCPPort(dstPort),
|
|
||||||
}
|
|
||||||
err = tcp.SetNetworkLayerForChecksum(ipLayer)
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcp)
|
|
||||||
|
|
||||||
case fw.ProtocolUDP:
|
if isV6 {
|
||||||
ipLayer.Protocol = layers.IPProtocolUDP
|
ip6 := &layers.IPv6{
|
||||||
udp := &layers.UDP{
|
Version: 6,
|
||||||
SrcPort: layers.UDPPort(srcPort),
|
HopLimit: 64,
|
||||||
DstPort: layers.UDPPort(dstPort),
|
SrcIP: src,
|
||||||
|
DstIP: dst,
|
||||||
}
|
}
|
||||||
err = udp.SetNetworkLayerForChecksum(ipLayer)
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = gopacket.SerializeLayers(buf, opts, ipLayer, udp)
|
|
||||||
|
|
||||||
case fw.ProtocolICMP:
|
switch proto {
|
||||||
ipLayer.Protocol = layers.IPProtocolICMPv4
|
case fw.ProtocolTCP:
|
||||||
icmp := &layers.ICMPv4{
|
ip6.NextHeader = layers.IPProtocolTCP
|
||||||
TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0),
|
tcp := &layers.TCP{SrcPort: layers.TCPPort(srcPort), DstPort: layers.TCPPort(dstPort)}
|
||||||
|
_ = tcp.SetNetworkLayerForChecksum(ip6)
|
||||||
|
err = gopacket.SerializeLayers(buf, opts, ip6, tcp)
|
||||||
|
case fw.ProtocolUDP:
|
||||||
|
ip6.NextHeader = layers.IPProtocolUDP
|
||||||
|
udp := &layers.UDP{SrcPort: layers.UDPPort(srcPort), DstPort: layers.UDPPort(dstPort)}
|
||||||
|
_ = udp.SetNetworkLayerForChecksum(ip6)
|
||||||
|
err = gopacket.SerializeLayers(buf, opts, ip6, udp)
|
||||||
|
case fw.ProtocolICMP:
|
||||||
|
ip6.NextHeader = layers.IPProtocolICMPv6
|
||||||
|
icmp := &layers.ICMPv6{
|
||||||
|
TypeCode: layers.CreateICMPv6TypeCode(layers.ICMPv6TypeEchoRequest, 0),
|
||||||
|
}
|
||||||
|
_ = icmp.SetNetworkLayerForChecksum(ip6)
|
||||||
|
err = gopacket.SerializeLayers(buf, opts, ip6, icmp)
|
||||||
|
default:
|
||||||
|
err = gopacket.SerializeLayers(buf, opts, ip6)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ip4 := &layers.IPv4{
|
||||||
|
Version: 4,
|
||||||
|
TTL: 64,
|
||||||
|
SrcIP: src,
|
||||||
|
DstIP: dst,
|
||||||
}
|
}
|
||||||
err = gopacket.SerializeLayers(buf, opts, ipLayer, icmp)
|
|
||||||
|
|
||||||
default:
|
switch proto {
|
||||||
err = gopacket.SerializeLayers(buf, opts, ipLayer)
|
case fw.ProtocolTCP:
|
||||||
|
ip4.Protocol = layers.IPProtocolTCP
|
||||||
|
tcp := &layers.TCP{SrcPort: layers.TCPPort(srcPort), DstPort: layers.TCPPort(dstPort)}
|
||||||
|
_ = tcp.SetNetworkLayerForChecksum(ip4)
|
||||||
|
err = gopacket.SerializeLayers(buf, opts, ip4, tcp)
|
||||||
|
case fw.ProtocolUDP:
|
||||||
|
ip4.Protocol = layers.IPProtocolUDP
|
||||||
|
udp := &layers.UDP{SrcPort: layers.UDPPort(srcPort), DstPort: layers.UDPPort(dstPort)}
|
||||||
|
_ = udp.SetNetworkLayerForChecksum(ip4)
|
||||||
|
err = gopacket.SerializeLayers(buf, opts, ip4, udp)
|
||||||
|
case fw.ProtocolICMP:
|
||||||
|
ip4.Protocol = layers.IPProtocolICMPv4
|
||||||
|
icmp := &layers.ICMPv4{TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0)}
|
||||||
|
err = gopacket.SerializeLayers(buf, opts, ip4, icmp)
|
||||||
|
default:
|
||||||
|
err = gopacket.SerializeLayers(buf, opts, ip4)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1498,3 +1681,103 @@ func TestRouteACLSet(t *testing.T) {
|
|||||||
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||||
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
|
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestRouteACLFilteringIPv6 tests IPv6 route ACL matching directly via routeACLsPass.
|
||||||
|
// Note: full FilterInbound for routed IPv6 traffic drops at the forwarder stage (IPv4-only)
|
||||||
|
// but the ACL decision itself is correct.
|
||||||
|
func TestRouteACLFilteringIPv6(t *testing.T) {
|
||||||
|
manager := setupRoutedManager(t, "10.10.0.100/16")
|
||||||
|
|
||||||
|
v6Dst := netip.MustParsePrefix("fd00:dead:beef::/48")
|
||||||
|
_, err := manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
|
||||||
|
fw.Network{Prefix: v6Dst},
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{80}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
|
||||||
|
fw.Network{Prefix: netip.MustParsePrefix("fd00:dead:beef:1::/64")},
|
||||||
|
fw.ProtocolALL,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionDrop,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
srcIP netip.Addr
|
||||||
|
dstIP netip.Addr
|
||||||
|
proto gopacket.LayerType
|
||||||
|
srcPort uint16
|
||||||
|
dstPort uint16
|
||||||
|
allowed bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "IPv6 TCP to allowed dest",
|
||||||
|
srcIP: netip.MustParseAddr("fd00::1"),
|
||||||
|
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
|
||||||
|
proto: layers.LayerTypeTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 80,
|
||||||
|
allowed: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 TCP wrong port",
|
||||||
|
srcIP: netip.MustParseAddr("fd00::1"),
|
||||||
|
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
|
||||||
|
proto: layers.LayerTypeTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 443,
|
||||||
|
allowed: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 UDP not matched by TCP rule",
|
||||||
|
srcIP: netip.MustParseAddr("fd00::1"),
|
||||||
|
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
|
||||||
|
proto: layers.LayerTypeUDP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 80,
|
||||||
|
allowed: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 ICMPv6 matches ICMP rule via protoLayerMatches",
|
||||||
|
srcIP: netip.MustParseAddr("fd00::1"),
|
||||||
|
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
|
||||||
|
proto: layers.LayerTypeICMPv6,
|
||||||
|
allowed: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 to denied subnet",
|
||||||
|
srcIP: netip.MustParseAddr("fd00::1"),
|
||||||
|
dstIP: netip.MustParseAddr("fd00:dead:beef:1::1"),
|
||||||
|
proto: layers.LayerTypeTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 80,
|
||||||
|
allowed: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 source outside allowed range",
|
||||||
|
srcIP: netip.MustParseAddr("fe80::1"),
|
||||||
|
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
|
||||||
|
proto: layers.LayerTypeTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 80,
|
||||||
|
allowed: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
_, pass := manager.routeACLsPass(tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||||
|
require.Equal(t, tc.allowed, pass, "route ACL result mismatch")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -527,11 +527,16 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
d := &decoder{
|
d := &decoder{
|
||||||
decoded: []gopacket.LayerType{},
|
decoded: []gopacket.LayerType{},
|
||||||
}
|
}
|
||||||
d.parser = gopacket.NewDecodingLayerParser(
|
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||||
layers.LayerTypeIPv4,
|
layers.LayerTypeIPv4,
|
||||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
)
|
)
|
||||||
d.parser.IgnoreUnsupported = true
|
d.parser4.IgnoreUnsupported = true
|
||||||
|
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||||
|
layers.LayerTypeIPv6,
|
||||||
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
|
)
|
||||||
|
d.parser6.IgnoreUnsupported = true
|
||||||
return d
|
return d
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -630,11 +635,16 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
d := &decoder{
|
d := &decoder{
|
||||||
decoded: []gopacket.LayerType{},
|
decoded: []gopacket.LayerType{},
|
||||||
}
|
}
|
||||||
d.parser = gopacket.NewDecodingLayerParser(
|
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||||
layers.LayerTypeIPv4,
|
layers.LayerTypeIPv4,
|
||||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
)
|
)
|
||||||
d.parser.IgnoreUnsupported = true
|
d.parser4.IgnoreUnsupported = true
|
||||||
|
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||||
|
layers.LayerTypeIPv6,
|
||||||
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
|
)
|
||||||
|
d.parser6.IgnoreUnsupported = true
|
||||||
return d
|
return d
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -1040,8 +1050,8 @@ func TestMSSClamping(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default")
|
require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default")
|
||||||
expectedMSSValue := uint16(1280 - ipTCPHeaderMinSize)
|
require.Equal(t, uint16(1280-ipv4TCPHeaderMinSize), manager.mssClampValueIPv4, "IPv4 MSS clamp value should be MTU - 40")
|
||||||
require.Equal(t, expectedMSSValue, manager.mssClampValue, "MSS clamp value should be MTU - 40")
|
require.Equal(t, uint16(1280-ipv6TCPHeaderMinSize), manager.mssClampValueIPv6, "IPv6 MSS clamp value should be MTU - 60")
|
||||||
|
|
||||||
err = manager.UpdateLocalIPs()
|
err = manager.UpdateLocalIPs()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1059,7 +1069,7 @@ func TestMSSClamping(t *testing.T) {
|
|||||||
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
|
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
|
||||||
require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType))
|
require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType))
|
||||||
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
|
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
|
||||||
require.Equal(t, expectedMSSValue, actualMSS, "MSS should be clamped to MTU - 40")
|
require.Equal(t, manager.mssClampValueIPv4, actualMSS, "MSS should be clamped to MTU - 40")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("SYN packet with low MSS unchanged", func(t *testing.T) {
|
t.Run("SYN packet with low MSS unchanged", func(t *testing.T) {
|
||||||
@@ -1083,7 +1093,7 @@ func TestMSSClamping(t *testing.T) {
|
|||||||
d := parsePacket(t, packet)
|
d := parsePacket(t, packet)
|
||||||
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
|
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
|
||||||
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
|
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
|
||||||
require.Equal(t, expectedMSSValue, actualMSS, "MSS in SYN-ACK should be clamped")
|
require.Equal(t, manager.mssClampValueIPv4, actualMSS, "MSS in SYN-ACK should be clamped")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Non-SYN packet unchanged", func(t *testing.T) {
|
t.Run("Non-SYN packet unchanged", func(t *testing.T) {
|
||||||
@@ -1255,13 +1265,18 @@ func TestShouldForward(t *testing.T) {
|
|||||||
d := &decoder{
|
d := &decoder{
|
||||||
decoded: []gopacket.LayerType{},
|
decoded: []gopacket.LayerType{},
|
||||||
}
|
}
|
||||||
d.parser = gopacket.NewDecodingLayerParser(
|
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||||
layers.LayerTypeIPv4,
|
layers.LayerTypeIPv4,
|
||||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
)
|
)
|
||||||
d.parser.IgnoreUnsupported = true
|
d.parser4.IgnoreUnsupported = true
|
||||||
|
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||||
|
layers.LayerTypeIPv6,
|
||||||
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
|
)
|
||||||
|
d.parser6.IgnoreUnsupported = true
|
||||||
|
|
||||||
err = d.parser.DecodeLayers(buf.Bytes(), &d.decoded)
|
err = d.decodePacket(buf.Bytes())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return d
|
return d
|
||||||
@@ -1321,6 +1336,44 @@ func TestShouldForward(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add IPv6 to the interface and test dual-stack cases
|
||||||
|
wgIPv6 := netip.MustParseAddr("fd00::1")
|
||||||
|
otherIPv6 := netip.MustParseAddr("fd00::2")
|
||||||
|
ifaceMock.AddressFunc = func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: wgIP,
|
||||||
|
Network: netip.PrefixFrom(wgIP, 24),
|
||||||
|
IPv6: wgIPv6,
|
||||||
|
IPv6Net: netip.PrefixFrom(wgIPv6, 64),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-create manager to pick up the new address with IPv6
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
manager, err = Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
v6Cases := []struct {
|
||||||
|
name string
|
||||||
|
dstIP netip.Addr
|
||||||
|
expected bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{"v6 traffic to other address", otherIPv6, true, "should forward v6 traffic not destined to our v6 address"},
|
||||||
|
{"v6 traffic to our v6 IP", wgIPv6, false, "should not forward traffic destined to our v6 address"},
|
||||||
|
{"v4 traffic to other with v6 configured", otherIP, true, "should forward v4 traffic when v6 configured"},
|
||||||
|
{"v4 traffic to our v4 IP with v6 configured", wgIP, false, "should not forward traffic to our v4 address"},
|
||||||
|
}
|
||||||
|
for _, tt := range v6Cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
manager.localForwarding = true
|
||||||
|
manager.netstack = false
|
||||||
|
decoder := createTCPDecoder(8080)
|
||||||
|
result := manager.shouldForward(decoder, tt.dstIP)
|
||||||
|
require.Equal(t, tt.expected, result, tt.description)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
// Configure manager
|
// Configure manager
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
package forwarder
|
package forwarder
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"net"
|
||||||
|
"strconv"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
@@ -47,17 +48,23 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
|
|||||||
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
|
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
|
||||||
var written int
|
var written int
|
||||||
for _, pkt := range pkts.AsSlice() {
|
for _, pkt := range pkts.AsSlice() {
|
||||||
netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice())
|
|
||||||
|
|
||||||
data := stack.PayloadSince(pkt.NetworkHeader())
|
data := stack.PayloadSince(pkt.NetworkHeader())
|
||||||
if data == nil {
|
if data == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send the packet through WireGuard
|
raw := pkt.NetworkHeader().View().AsSlice()
|
||||||
address := netHeader.DestinationAddress()
|
if len(raw) == 0 {
|
||||||
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
|
continue
|
||||||
if err != nil {
|
}
|
||||||
|
var address tcpip.Address
|
||||||
|
if raw[0]>>4 == 6 {
|
||||||
|
address = header.IPv6(raw).DestinationAddress()
|
||||||
|
} else {
|
||||||
|
address = header.IPv4(raw).DestinationAddress()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice()); err != nil {
|
||||||
e.logger.Error1("CreateOutboundPacket: %v", err)
|
e.logger.Error1("CreateOutboundPacket: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -103,5 +110,7 @@ type epID stack.TransportEndpointID
|
|||||||
|
|
||||||
func (i epID) String() string {
|
func (i epID) String() string {
|
||||||
// src and remote is swapped
|
// src and remote is swapped
|
||||||
return fmt.Sprintf("%s:%d → %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
|
return net.JoinHostPort(i.RemoteAddress.String(), strconv.Itoa(int(i.RemotePort))) +
|
||||||
|
" → " +
|
||||||
|
net.JoinHostPort(i.LocalAddress.String(), strconv.Itoa(int(i.LocalPort)))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
@@ -36,25 +37,31 @@ type Forwarder struct {
|
|||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
flowLogger nftypes.FlowLogger
|
flowLogger nftypes.FlowLogger
|
||||||
// ruleIdMap is used to store the rule ID for a given connection
|
// ruleIdMap is used to store the rule ID for a given connection
|
||||||
ruleIdMap sync.Map
|
ruleIdMap sync.Map
|
||||||
stack *stack.Stack
|
stack *stack.Stack
|
||||||
endpoint *endpoint
|
endpoint *endpoint
|
||||||
udpForwarder *udpForwarder
|
udpForwarder *udpForwarder
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
ip tcpip.Address
|
ip tcpip.Address
|
||||||
netstack bool
|
ipv6 tcpip.Address
|
||||||
hasRawICMPAccess bool
|
netstack bool
|
||||||
pingSemaphore chan struct{}
|
hasRawICMPAccess bool
|
||||||
|
hasRawICMPv6Access bool
|
||||||
|
pingSemaphore chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
|
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
|
||||||
s := stack.New(stack.Options{
|
s := stack.New(stack.Options{
|
||||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
NetworkProtocols: []stack.NetworkProtocolFactory{
|
||||||
|
ipv4.NewProtocol,
|
||||||
|
ipv6.NewProtocol,
|
||||||
|
},
|
||||||
TransportProtocols: []stack.TransportProtocolFactory{
|
TransportProtocols: []stack.TransportProtocolFactory{
|
||||||
tcp.NewProtocol,
|
tcp.NewProtocol,
|
||||||
udp.NewProtocol,
|
udp.NewProtocol,
|
||||||
icmp.NewProtocol4,
|
icmp.NewProtocol4,
|
||||||
|
icmp.NewProtocol6,
|
||||||
},
|
},
|
||||||
HandleLocal: false,
|
HandleLocal: false,
|
||||||
})
|
})
|
||||||
@@ -73,7 +80,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
|||||||
protoAddr := tcpip.ProtocolAddress{
|
protoAddr := tcpip.ProtocolAddress{
|
||||||
Protocol: ipv4.ProtocolNumber,
|
Protocol: ipv4.ProtocolNumber,
|
||||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||||
Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
Address: tcpip.AddrFrom4(iface.Address().IP.As4()),
|
||||||
PrefixLen: iface.Address().Network.Bits(),
|
PrefixLen: iface.Address().Network.Bits(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -82,6 +89,19 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
|||||||
return nil, fmt.Errorf("failed to add protocol address: %s", err)
|
return nil, fmt.Errorf("failed to add protocol address: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if v6 := iface.Address().IPv6; v6.IsValid() {
|
||||||
|
v6Addr := tcpip.ProtocolAddress{
|
||||||
|
Protocol: ipv6.ProtocolNumber,
|
||||||
|
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||||
|
Address: tcpip.AddrFrom16(v6.As16()),
|
||||||
|
PrefixLen: iface.Address().IPv6Net.Bits(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err := s.AddProtocolAddress(nicID, v6Addr, stack.AddressProperties{}); err != nil {
|
||||||
|
return nil, fmt.Errorf("add IPv6 protocol address: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
defaultSubnet, err := tcpip.NewSubnet(
|
defaultSubnet, err := tcpip.NewSubnet(
|
||||||
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
|
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
|
||||||
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
|
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
|
||||||
@@ -90,6 +110,14 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
|||||||
return nil, fmt.Errorf("creating default subnet: %w", err)
|
return nil, fmt.Errorf("creating default subnet: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defaultSubnetV6, err := tcpip.NewSubnet(
|
||||||
|
tcpip.AddrFrom16([16]byte{}),
|
||||||
|
tcpip.MaskFromBytes(make([]byte, 16)),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating default v6 subnet: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := s.SetPromiscuousMode(nicID, true); err != nil {
|
if err := s.SetPromiscuousMode(nicID, true); err != nil {
|
||||||
return nil, fmt.Errorf("set promiscuous mode: %s", err)
|
return nil, fmt.Errorf("set promiscuous mode: %s", err)
|
||||||
}
|
}
|
||||||
@@ -98,10 +126,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.SetRouteTable([]tcpip.Route{
|
s.SetRouteTable([]tcpip.Route{
|
||||||
{
|
{Destination: defaultSubnet, NIC: nicID},
|
||||||
Destination: defaultSubnet,
|
{Destination: defaultSubnetV6, NIC: nicID},
|
||||||
NIC: nicID,
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
@@ -114,7 +140,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
|||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
netstack: netstack,
|
netstack: netstack,
|
||||||
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
ip: tcpip.AddrFrom4(iface.Address().IP.As4()),
|
||||||
|
ipv6: addrFromNetipAddr(iface.Address().IPv6),
|
||||||
pingSemaphore: make(chan struct{}, 3),
|
pingSemaphore: make(chan struct{}, 3),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -131,7 +158,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
|||||||
udpForwarder := udp.NewForwarder(s, f.handleUDP)
|
udpForwarder := udp.NewForwarder(s, f.handleUDP)
|
||||||
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
||||||
|
|
||||||
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
|
// ICMP is handled directly in InjectIncomingPacket, bypassing gVisor's
|
||||||
|
// network layer. This avoids duplicate echo replies (v4) and the v6
|
||||||
|
// auto-reply bug where gVisor responds at the network layer before
|
||||||
|
// our transport handler fires.
|
||||||
|
|
||||||
f.checkICMPCapability()
|
f.checkICMPCapability()
|
||||||
|
|
||||||
@@ -140,8 +170,30 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
||||||
if len(payload) < header.IPv4MinimumSize {
|
if len(payload) == 0 {
|
||||||
return fmt.Errorf("packet too small: %d bytes", len(payload))
|
return fmt.Errorf("empty packet")
|
||||||
|
}
|
||||||
|
|
||||||
|
var protoNum tcpip.NetworkProtocolNumber
|
||||||
|
switch payload[0] >> 4 {
|
||||||
|
case 4:
|
||||||
|
if len(payload) < header.IPv4MinimumSize {
|
||||||
|
return fmt.Errorf("IPv4 packet too small: %d bytes", len(payload))
|
||||||
|
}
|
||||||
|
if f.handleICMPDirect(payload) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
protoNum = ipv4.ProtocolNumber
|
||||||
|
case 6:
|
||||||
|
if len(payload) < header.IPv6MinimumSize {
|
||||||
|
return fmt.Errorf("IPv6 packet too small: %d bytes", len(payload))
|
||||||
|
}
|
||||||
|
if f.handleICMPDirect(payload) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
protoNum = ipv6.ProtocolNumber
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown IP version: %d", payload[0]>>4)
|
||||||
}
|
}
|
||||||
|
|
||||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||||
@@ -150,11 +202,95 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
|||||||
defer pkt.DecRef()
|
defer pkt.DecRef()
|
||||||
|
|
||||||
if f.endpoint.dispatcher != nil {
|
if f.endpoint.dispatcher != nil {
|
||||||
f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
|
f.endpoint.dispatcher.DeliverNetworkPacket(protoNum, pkt)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleICMPDirect intercepts ICMP packets from raw IP payloads before they
|
||||||
|
// enter gVisor. It synthesizes the TransportEndpointID and PacketBuffer that
|
||||||
|
// the existing handlers expect, then dispatches to handleICMP/handleICMPv6.
|
||||||
|
// This bypasses gVisor's network layer which causes duplicate v4 echo replies
|
||||||
|
// and auto-replies to all v6 echo requests in promiscuous mode.
|
||||||
|
//
|
||||||
|
// Unlike gVisor's network layer, this does not validate ICMP checksums or
|
||||||
|
// reassemble IP fragments. Fragmented ICMP packets fall through to gVisor.
|
||||||
|
func parseICMPv4(payload []byte) (ipHdrLen int, src, dst tcpip.Address, ok bool) {
|
||||||
|
ip := header.IPv4(payload)
|
||||||
|
if ip.Protocol() != uint8(header.ICMPv4ProtocolNumber) {
|
||||||
|
return 0, src, dst, false
|
||||||
|
}
|
||||||
|
if ip.FragmentOffset() != 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 {
|
||||||
|
return 0, src, dst, false
|
||||||
|
}
|
||||||
|
ipHdrLen = int(ip.HeaderLength())
|
||||||
|
if len(payload)-ipHdrLen < header.ICMPv4MinimumSize {
|
||||||
|
return 0, src, dst, false
|
||||||
|
}
|
||||||
|
return ipHdrLen, ip.SourceAddress(), ip.DestinationAddress(), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseICMPv6(payload []byte) (ipHdrLen int, src, dst tcpip.Address, ok bool) {
|
||||||
|
ip := header.IPv6(payload)
|
||||||
|
if ip.NextHeader() != uint8(header.ICMPv6ProtocolNumber) {
|
||||||
|
return 0, src, dst, false
|
||||||
|
}
|
||||||
|
ipHdrLen = header.IPv6MinimumSize
|
||||||
|
if len(payload)-ipHdrLen < header.ICMPv6MinimumSize {
|
||||||
|
return 0, src, dst, false
|
||||||
|
}
|
||||||
|
return ipHdrLen, ip.SourceAddress(), ip.DestinationAddress(), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) handleICMPDirect(payload []byte) bool {
|
||||||
|
var (
|
||||||
|
ipHdrLen int
|
||||||
|
srcAddr tcpip.Address
|
||||||
|
dstAddr tcpip.Address
|
||||||
|
ok bool
|
||||||
|
)
|
||||||
|
switch payload[0] >> 4 {
|
||||||
|
case 4:
|
||||||
|
ipHdrLen, srcAddr, dstAddr, ok = parseICMPv4(payload)
|
||||||
|
case 6:
|
||||||
|
ipHdrLen, srcAddr, dstAddr, ok = parseICMPv6(payload)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Let gVisor handle ICMP destined for our own addresses natively.
|
||||||
|
// Its network-layer auto-reply is correct and efficient for local traffic.
|
||||||
|
if f.ip.Equal(dstAddr) || f.ipv6.Equal(dstAddr) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
id := stack.TransportEndpointID{
|
||||||
|
LocalAddress: dstAddr,
|
||||||
|
RemoteAddress: srcAddr,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a PacketBuffer with headers consumed the same way gVisor would.
|
||||||
|
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||||
|
Payload: buffer.MakeWithData(payload),
|
||||||
|
})
|
||||||
|
defer pkt.DecRef()
|
||||||
|
|
||||||
|
if _, ok := pkt.NetworkHeader().Consume(ipHdrLen); !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
icmpPayload := payload[ipHdrLen:]
|
||||||
|
if _, ok := pkt.TransportHeader().Consume(len(icmpPayload)); !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if payload[0]>>4 == 6 {
|
||||||
|
return f.handleICMPv6(id, pkt)
|
||||||
|
}
|
||||||
|
return f.handleICMP(id, pkt)
|
||||||
|
}
|
||||||
|
|
||||||
// Stop gracefully shuts down the forwarder
|
// Stop gracefully shuts down the forwarder
|
||||||
func (f *Forwarder) Stop() {
|
func (f *Forwarder) Stop() {
|
||||||
f.cancel()
|
f.cancel()
|
||||||
@@ -167,11 +303,14 @@ func (f *Forwarder) Stop() {
|
|||||||
f.stack.Wait()
|
f.stack.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
func (f *Forwarder) determineDialAddr(addr tcpip.Address) netip.Addr {
|
||||||
if f.netstack && f.ip.Equal(addr) {
|
if f.netstack && f.ip.Equal(addr) {
|
||||||
return net.IPv4(127, 0, 0, 1)
|
return netip.AddrFrom4([4]byte{127, 0, 0, 1})
|
||||||
}
|
}
|
||||||
return addr.AsSlice()
|
if f.netstack && f.ipv6.Equal(addr) {
|
||||||
|
return netip.IPv6Loopback()
|
||||||
|
}
|
||||||
|
return addrToNetipAddr(addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) {
|
func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) {
|
||||||
@@ -205,23 +344,50 @@ func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// addrFromNetipAddr converts a netip.Addr to a gvisor tcpip.Address without allocating.
|
||||||
|
func addrFromNetipAddr(addr netip.Addr) tcpip.Address {
|
||||||
|
if !addr.IsValid() {
|
||||||
|
return tcpip.Address{}
|
||||||
|
}
|
||||||
|
if addr.Is4() {
|
||||||
|
return tcpip.AddrFrom4(addr.As4())
|
||||||
|
}
|
||||||
|
return tcpip.AddrFrom16(addr.As16())
|
||||||
|
}
|
||||||
|
|
||||||
|
// addrToNetipAddr converts a gvisor tcpip.Address to netip.Addr without allocating.
|
||||||
|
func addrToNetipAddr(addr tcpip.Address) netip.Addr {
|
||||||
|
switch addr.Len() {
|
||||||
|
case 4:
|
||||||
|
return netip.AddrFrom4(addr.As4())
|
||||||
|
case 16:
|
||||||
|
return netip.AddrFrom16(addr.As16())
|
||||||
|
default:
|
||||||
|
return netip.Addr{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// checkICMPCapability tests whether we have raw ICMP socket access at startup.
|
// checkICMPCapability tests whether we have raw ICMP socket access at startup.
|
||||||
func (f *Forwarder) checkICMPCapability() {
|
func (f *Forwarder) checkICMPCapability() {
|
||||||
|
f.hasRawICMPAccess = probeRawICMP("ip4:icmp", "0.0.0.0", f.logger)
|
||||||
|
f.hasRawICMPv6Access = probeRawICMP("ip6:ipv6-icmp", "::", f.logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func probeRawICMP(network, addr string, logger *nblog.Logger) bool {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
lc := net.ListenConfig{}
|
lc := net.ListenConfig{}
|
||||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
conn, err := lc.ListenPacket(ctx, network, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.hasRawICMPAccess = false
|
logger.Debug1("forwarder: no raw %s socket access, will use ping binary fallback", network)
|
||||||
f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback")
|
return false
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := conn.Close(); err != nil {
|
if err := conn.Close(); err != nil {
|
||||||
f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err)
|
logger.Debug2("forwarder: failed to close %s capability test socket: %v", network, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.hasRawICMPAccess = true
|
logger.Debug1("forwarder: raw %s socket access available", network)
|
||||||
f.logger.Debug("forwarder: Raw ICMP socket access available")
|
return true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt *stack.PacketBu
|
|||||||
}
|
}
|
||||||
|
|
||||||
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
|
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
|
||||||
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 100*time.Millisecond)
|
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), false, 100*time.Millisecond)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err)
|
f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err)
|
||||||
return true
|
return true
|
||||||
@@ -58,7 +58,7 @@ func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointI
|
|||||||
defer func() { <-f.pingSemaphore }()
|
defer func() { <-f.pingSemaphore }()
|
||||||
|
|
||||||
if f.hasRawICMPAccess {
|
if f.hasRawICMPAccess {
|
||||||
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
|
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes, false)
|
||||||
} else {
|
} else {
|
||||||
f.handleICMPViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
|
f.handleICMPViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
|
||||||
}
|
}
|
||||||
@@ -72,18 +72,23 @@ func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointI
|
|||||||
|
|
||||||
// forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection.
|
// forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection.
|
||||||
// The caller is responsible for closing the returned connection.
|
// The caller is responsible for closing the returned connection.
|
||||||
func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, timeout time.Duration) (net.PacketConn, error) {
|
func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, v6 bool, timeout time.Duration) (net.PacketConn, error) {
|
||||||
ctx, cancel := context.WithTimeout(f.ctx, timeout)
|
ctx, cancel := context.WithTimeout(f.ctx, timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
network, listenAddr := "ip4:icmp", "0.0.0.0"
|
||||||
|
if v6 {
|
||||||
|
network, listenAddr = "ip6:ipv6-icmp", "::"
|
||||||
|
}
|
||||||
|
|
||||||
lc := net.ListenConfig{}
|
lc := net.ListenConfig{}
|
||||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
conn, err := lc.ListenPacket(ctx, network, listenAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create ICMP socket: %w", err)
|
return nil, fmt.Errorf("create ICMP socket: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||||
dst := &net.IPAddr{IP: dstIP}
|
dst := &net.IPAddr{IP: dstIP.AsSlice()}
|
||||||
|
|
||||||
if _, err = conn.WriteTo(payload, dst); err != nil {
|
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||||
if closeErr := conn.Close(); closeErr != nil {
|
if closeErr := conn.Close(); closeErr != nil {
|
||||||
@@ -98,11 +103,11 @@ func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []by
|
|||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleICMPViaSocket handles ICMP echo requests using raw sockets.
|
// handleICMPViaSocket handles ICMP echo requests using raw sockets for both v4 and v6.
|
||||||
func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
|
func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int, v6 bool) {
|
||||||
sendTime := time.Now()
|
sendTime := time.Now()
|
||||||
|
|
||||||
conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, 5*time.Second)
|
conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, v6, 5*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err)
|
f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err)
|
||||||
return
|
return
|
||||||
@@ -113,16 +118,20 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
txBytes := f.handleEchoResponse(conn, id)
|
txBytes := f.handleEchoResponse(conn, id, v6)
|
||||||
rtt := time.Since(sendTime).Round(10 * time.Microsecond)
|
rtt := time.Since(sendTime).Round(10 * time.Microsecond)
|
||||||
|
|
||||||
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)",
|
proto := "ICMP"
|
||||||
epID(id), icmpType, icmpCode, rtt)
|
if v6 {
|
||||||
|
proto = "ICMPv6"
|
||||||
|
}
|
||||||
|
f.logger.Trace5("forwarder: Forwarded %s echo reply %v type %v code %v (rtt=%v, raw socket)",
|
||||||
|
proto, epID(id), icmpType, icmpCode, rtt)
|
||||||
|
|
||||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) int {
|
func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID, v6 bool) int {
|
||||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||||
f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err)
|
f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err)
|
||||||
return 0
|
return 0
|
||||||
@@ -137,6 +146,19 @@ func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEn
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if v6 {
|
||||||
|
// Recompute checksum: the raw socket response has a checksum computed
|
||||||
|
// over the real endpoint addresses, but we inject with overlay addresses.
|
||||||
|
icmpHdr := header.ICMPv6(response[:n])
|
||||||
|
icmpHdr.SetChecksum(0)
|
||||||
|
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
||||||
|
Header: icmpHdr,
|
||||||
|
Src: id.LocalAddress,
|
||||||
|
Dst: id.RemoteAddress,
|
||||||
|
}))
|
||||||
|
return f.injectICMPv6Reply(id, response[:n])
|
||||||
|
}
|
||||||
|
|
||||||
return f.injectICMPReply(id, response[:n])
|
return f.injectICMPReply(id, response[:n])
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -150,19 +172,23 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T
|
|||||||
txPackets = 1
|
txPackets = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
srcIp := addrToNetipAddr(id.RemoteAddress)
|
||||||
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
dstIp := addrToNetipAddr(id.LocalAddress)
|
||||||
|
|
||||||
|
proto := nftypes.ICMP
|
||||||
|
if srcIp.Is6() {
|
||||||
|
proto = nftypes.ICMPv6
|
||||||
|
}
|
||||||
|
|
||||||
fields := nftypes.EventFields{
|
fields := nftypes.EventFields{
|
||||||
FlowID: flowID,
|
FlowID: flowID,
|
||||||
Type: typ,
|
Type: typ,
|
||||||
Direction: nftypes.Ingress,
|
Direction: nftypes.Ingress,
|
||||||
Protocol: nftypes.ICMP,
|
Protocol: proto,
|
||||||
// TODO: handle ipv6
|
SourceIP: srcIp,
|
||||||
SourceIP: srcIp,
|
DestIP: dstIp,
|
||||||
DestIP: dstIp,
|
ICMPType: icmpType,
|
||||||
ICMPType: icmpType,
|
ICMPCode: icmpCode,
|
||||||
ICMPCode: icmpCode,
|
|
||||||
|
|
||||||
RxBytes: rxBytes,
|
RxBytes: rxBytes,
|
||||||
TxBytes: txBytes,
|
TxBytes: txBytes,
|
||||||
@@ -209,26 +235,164 @@ func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpoi
|
|||||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleICMPv6 handles ICMPv6 packets from the network stack.
|
||||||
|
func (f *Forwarder) handleICMPv6(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
|
||||||
|
icmpHdr := header.ICMPv6(pkt.TransportHeader().View().AsSlice())
|
||||||
|
|
||||||
|
flowID := uuid.New()
|
||||||
|
f.sendICMPEvent(nftypes.TypeStart, flowID, id, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 0, 0)
|
||||||
|
|
||||||
|
if icmpHdr.Type() == header.ICMPv6EchoRequest {
|
||||||
|
return f.handleICMPv6Echo(flowID, id, pkt, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// For non-echo types (Destination Unreachable, Packet Too Big, etc), forward without waiting
|
||||||
|
if !f.hasRawICMPv6Access {
|
||||||
|
f.logger.Debug2("forwarder: Cannot handle ICMPv6 type %v without raw socket access for %v", icmpHdr.Type(), epID(id))
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
|
||||||
|
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), true, 100*time.Millisecond)
|
||||||
|
if err != nil {
|
||||||
|
f.logger.Error2("forwarder: Failed to forward ICMPv6 packet for %v: %v", epID(id), err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
f.logger.Debug1("forwarder: Failed to close ICMPv6 socket: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleICMPv6Echo handles ICMPv6 echo requests via raw socket or ping binary fallback.
|
||||||
|
func (f *Forwarder) handleICMPv6Echo(flowID uuid.UUID, id stack.TransportEndpointID, pkt *stack.PacketBuffer, icmpType, icmpCode uint8) bool {
|
||||||
|
select {
|
||||||
|
case f.pingSemaphore <- struct{}{}:
|
||||||
|
icmpData := stack.PayloadSince(pkt.TransportHeader()).ToSlice()
|
||||||
|
rxBytes := pkt.Size()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() { <-f.pingSemaphore }()
|
||||||
|
|
||||||
|
if f.hasRawICMPv6Access {
|
||||||
|
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes, true)
|
||||||
|
} else {
|
||||||
|
f.handleICMPv6ViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
default:
|
||||||
|
f.logger.Debug3("forwarder: ICMPv6 rate limit exceeded for %v type %v code %v", epID(id), icmpType, icmpCode)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleICMPv6ViaPing uses the system ping6 binary for ICMPv6 echo.
|
||||||
|
func (f *Forwarder) handleICMPv6ViaPing(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
|
||||||
|
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||||
|
cmd := buildPingCommand(ctx, dstIP, 5*time.Second)
|
||||||
|
|
||||||
|
pingStart := time.Now()
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
f.logger.Warn4("forwarder: Ping6 failed for %v type %v code %v: %v", epID(id), icmpType, icmpCode, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
|
||||||
|
|
||||||
|
f.logger.Trace3("forwarder: Forwarded ICMPv6 echo request %v type %v code %v",
|
||||||
|
epID(id), icmpType, icmpCode)
|
||||||
|
|
||||||
|
txBytes := f.synthesizeICMPv6EchoReply(id, icmpData)
|
||||||
|
|
||||||
|
f.logger.Trace4("forwarder: Forwarded ICMPv6 echo reply %v type %v code %v (rtt=%v, ping binary)",
|
||||||
|
epID(id), icmpType, icmpCode, rtt)
|
||||||
|
|
||||||
|
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// synthesizeICMPv6EchoReply creates an ICMPv6 echo reply and injects it back.
|
||||||
|
func (f *Forwarder) synthesizeICMPv6EchoReply(id stack.TransportEndpointID, icmpData []byte) int {
|
||||||
|
replyICMP := make([]byte, len(icmpData))
|
||||||
|
copy(replyICMP, icmpData)
|
||||||
|
|
||||||
|
replyHdr := header.ICMPv6(replyICMP)
|
||||||
|
replyHdr.SetType(header.ICMPv6EchoReply)
|
||||||
|
replyHdr.SetChecksum(0)
|
||||||
|
// ICMPv6Checksum computes the pseudo-header internally from Src/Dst.
|
||||||
|
// Header contains the full ICMP message, so PayloadCsum/PayloadLen are zero.
|
||||||
|
replyHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
||||||
|
Header: replyHdr,
|
||||||
|
Src: id.LocalAddress,
|
||||||
|
Dst: id.RemoteAddress,
|
||||||
|
}))
|
||||||
|
|
||||||
|
return f.injectICMPv6Reply(id, replyICMP)
|
||||||
|
}
|
||||||
|
|
||||||
|
// injectICMPv6Reply wraps an ICMPv6 payload in an IPv6 header and sends to the peer.
|
||||||
|
func (f *Forwarder) injectICMPv6Reply(id stack.TransportEndpointID, icmpPayload []byte) int {
|
||||||
|
ipHdr := make([]byte, header.IPv6MinimumSize)
|
||||||
|
ip := header.IPv6(ipHdr)
|
||||||
|
ip.Encode(&header.IPv6Fields{
|
||||||
|
PayloadLength: uint16(len(icmpPayload)),
|
||||||
|
TransportProtocol: header.ICMPv6ProtocolNumber,
|
||||||
|
HopLimit: 64,
|
||||||
|
SrcAddr: id.LocalAddress,
|
||||||
|
DstAddr: id.RemoteAddress,
|
||||||
|
})
|
||||||
|
|
||||||
|
fullPacket := make([]byte, 0, len(ipHdr)+len(icmpPayload))
|
||||||
|
fullPacket = append(fullPacket, ipHdr...)
|
||||||
|
fullPacket = append(fullPacket, icmpPayload...)
|
||||||
|
|
||||||
|
if err := f.endpoint.device.CreateOutboundPacket(fullPacket, id.RemoteAddress.AsSlice()); err != nil {
|
||||||
|
f.logger.Error1("forwarder: Failed to send ICMPv6 reply to peer: %v", err)
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(fullPacket)
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
pingBin = "ping"
|
||||||
|
ping6Bin = "ping6"
|
||||||
|
)
|
||||||
|
|
||||||
// buildPingCommand creates a platform-specific ping command.
|
// buildPingCommand creates a platform-specific ping command.
|
||||||
func buildPingCommand(ctx context.Context, target net.IP, timeout time.Duration) *exec.Cmd {
|
// Most platforms auto-detect IPv6 from raw addresses. macOS/iOS/OpenBSD require ping6.
|
||||||
|
func buildPingCommand(ctx context.Context, target netip.Addr, timeout time.Duration) *exec.Cmd {
|
||||||
timeoutSec := int(timeout.Seconds())
|
timeoutSec := int(timeout.Seconds())
|
||||||
if timeoutSec < 1 {
|
if timeoutSec < 1 {
|
||||||
timeoutSec = 1
|
timeoutSec = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
isV6 := target.Is6()
|
||||||
|
timeoutStr := fmt.Sprintf("%d", timeoutSec)
|
||||||
|
|
||||||
switch runtime.GOOS {
|
switch runtime.GOOS {
|
||||||
case "linux", "android":
|
case "linux", "android":
|
||||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-W", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
|
return exec.CommandContext(ctx, pingBin, "-c", "1", "-W", timeoutStr, "-q", target.String())
|
||||||
case "darwin", "ios":
|
case "darwin", "ios":
|
||||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
|
bin := pingBin
|
||||||
|
if isV6 {
|
||||||
|
bin = ping6Bin
|
||||||
|
}
|
||||||
|
return exec.CommandContext(ctx, bin, "-c", "1", "-t", timeoutStr, "-q", target.String())
|
||||||
case "freebsd":
|
case "freebsd":
|
||||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), target.String())
|
return exec.CommandContext(ctx, pingBin, "-c", "1", "-t", timeoutStr, target.String())
|
||||||
case "openbsd", "netbsd":
|
case "openbsd", "netbsd":
|
||||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-w", fmt.Sprintf("%d", timeoutSec), target.String())
|
bin := pingBin
|
||||||
|
if isV6 {
|
||||||
|
bin = ping6Bin
|
||||||
|
}
|
||||||
|
return exec.CommandContext(ctx, bin, "-c", "1", "-w", timeoutStr, target.String())
|
||||||
case "windows":
|
case "windows":
|
||||||
return exec.CommandContext(ctx, "ping", "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
|
return exec.CommandContext(ctx, pingBin, "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
|
||||||
default:
|
default:
|
||||||
return exec.CommandContext(ctx, "ping", "-c", "1", target.String())
|
return exec.CommandContext(ctx, pingBin, "-c", "1", target.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,10 +2,9 @@ package forwarder
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@@ -33,7 +32,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
dialAddr := net.JoinHostPort(f.determineDialAddr(id.LocalAddress).String(), strconv.Itoa(int(id.LocalPort)))
|
||||||
|
|
||||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -133,15 +132,14 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
|
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
|
||||||
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
srcIp := addrToNetipAddr(id.RemoteAddress)
|
||||||
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
dstIp := addrToNetipAddr(id.LocalAddress)
|
||||||
|
|
||||||
fields := nftypes.EventFields{
|
fields := nftypes.EventFields{
|
||||||
FlowID: flowID,
|
FlowID: flowID,
|
||||||
Type: typ,
|
Type: typ,
|
||||||
Direction: nftypes.Ingress,
|
Direction: nftypes.Ingress,
|
||||||
Protocol: nftypes.TCP,
|
Protocol: nftypes.TCP,
|
||||||
// TODO: handle ipv6
|
|
||||||
SourceIP: srcIp,
|
SourceIP: srcIp,
|
||||||
DestIP: dstIp,
|
DestIP: dstIp,
|
||||||
SourcePort: id.RemotePort,
|
SourcePort: id.RemotePort,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@@ -158,7 +158,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
dstAddr := net.JoinHostPort(f.determineDialAddr(id.LocalAddress).String(), strconv.Itoa(int(id.LocalPort)))
|
||||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err)
|
f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err)
|
||||||
@@ -276,15 +276,14 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
|||||||
|
|
||||||
// sendUDPEvent stores flow events for UDP connections
|
// sendUDPEvent stores flow events for UDP connections
|
||||||
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
|
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
|
||||||
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
srcIp := addrToNetipAddr(id.RemoteAddress)
|
||||||
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
dstIp := addrToNetipAddr(id.LocalAddress)
|
||||||
|
|
||||||
fields := nftypes.EventFields{
|
fields := nftypes.EventFields{
|
||||||
FlowID: flowID,
|
FlowID: flowID,
|
||||||
Type: typ,
|
Type: typ,
|
||||||
Direction: nftypes.Ingress,
|
Direction: nftypes.Ingress,
|
||||||
Protocol: nftypes.UDP,
|
Protocol: nftypes.UDP,
|
||||||
// TODO: handle ipv6
|
|
||||||
SourceIP: srcIp,
|
SourceIP: srcIp,
|
||||||
DestIP: dstIp,
|
DestIP: dstIp,
|
||||||
SourcePort: id.RemotePort,
|
SourcePort: id.RemotePort,
|
||||||
|
|||||||
@@ -4,89 +4,32 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync/atomic"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
type localIPManager struct {
|
// localIPSnapshot is an immutable snapshot of local IP addresses, swapped
|
||||||
mu sync.RWMutex
|
// atomically so reads are lock-free.
|
||||||
|
type localIPSnapshot struct {
|
||||||
// fixed-size high array for upper byte of a IPv4 address
|
ips map[netip.Addr]struct{}
|
||||||
ipv4Bitmap [256]*ipv4LowBitmap
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address
|
type localIPManager struct {
|
||||||
type ipv4LowBitmap struct {
|
snapshot atomic.Pointer[localIPSnapshot]
|
||||||
bitmap [8192]uint32
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newLocalIPManager() *localIPManager {
|
func newLocalIPManager() *localIPManager {
|
||||||
return &localIPManager{}
|
m := &localIPManager{}
|
||||||
|
m.snapshot.Store(&localIPSnapshot{
|
||||||
|
ips: make(map[netip.Addr]struct{}),
|
||||||
|
})
|
||||||
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) setBitmapBit(ip net.IP) {
|
func processInterface(iface net.Interface, ips map[netip.Addr]struct{}, addresses *[]netip.Addr) {
|
||||||
ipv4 := ip.To4()
|
|
||||||
if ipv4 == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
high := uint16(ipv4[0])
|
|
||||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
|
||||||
|
|
||||||
index := low / 32
|
|
||||||
bit := low % 32
|
|
||||||
|
|
||||||
if m.ipv4Bitmap[high] == nil {
|
|
||||||
m.ipv4Bitmap[high] = &ipv4LowBitmap{}
|
|
||||||
}
|
|
||||||
|
|
||||||
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
|
||||||
if !ip.Is4() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ipv4 := ip.AsSlice()
|
|
||||||
|
|
||||||
high := uint16(ipv4[0])
|
|
||||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
|
||||||
|
|
||||||
if bitmap[high] == nil {
|
|
||||||
bitmap[high] = &ipv4LowBitmap{}
|
|
||||||
}
|
|
||||||
|
|
||||||
index := low / 32
|
|
||||||
bit := low % 32
|
|
||||||
bitmap[high].bitmap[index] |= 1 << bit
|
|
||||||
|
|
||||||
if _, exists := ipv4Set[ip]; !exists {
|
|
||||||
ipv4Set[ip] = struct{}{}
|
|
||||||
*ipv4Addresses = append(*ipv4Addresses, ip)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
|
||||||
high := uint16(ip[0])
|
|
||||||
low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3])
|
|
||||||
|
|
||||||
if m.ipv4Bitmap[high] == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
index := low / 32
|
|
||||||
bit := low % 32
|
|
||||||
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
|
|
||||||
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
|
||||||
addrs, err := iface.Addrs()
|
addrs, err := iface.Addrs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||||
@@ -104,18 +47,19 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
addr, ok := netip.AddrFromSlice(ip)
|
parsed, ok := netip.AddrFromSlice(ip)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
|
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
|
parsed = parsed.Unmap()
|
||||||
log.Debugf("process IP failed: %v", err)
|
ips[parsed] = struct{}{}
|
||||||
}
|
*addresses = append(*addresses, parsed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateLocalIPs rebuilds the local IP snapshot and swaps it in atomically.
|
||||||
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
@@ -123,20 +67,20 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var newIPv4Bitmap [256]*ipv4LowBitmap
|
ips := make(map[netip.Addr]struct{})
|
||||||
ipv4Set := make(map[netip.Addr]struct{})
|
var addresses []netip.Addr
|
||||||
var ipv4Addresses []netip.Addr
|
|
||||||
|
|
||||||
// 127.0.0.0/8
|
// loopback
|
||||||
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
ips[netip.AddrFrom4([4]byte{127, 0, 0, 1})] = struct{}{}
|
||||||
for i := 0; i < 8192; i++ {
|
ips[netip.IPv6Loopback()] = struct{}{}
|
||||||
// #nosec G602 -- bitmap is defined as [8192]uint32, loop range is correct
|
|
||||||
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
|
|
||||||
}
|
|
||||||
|
|
||||||
if iface != nil {
|
if iface != nil {
|
||||||
if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil {
|
ip := iface.Address().IP
|
||||||
return err
|
ips[ip] = struct{}{}
|
||||||
|
addresses = append(addresses, ip)
|
||||||
|
if v6 := iface.Address().IPv6; v6.IsValid() {
|
||||||
|
ips[v6] = struct{}{}
|
||||||
|
addresses = append(addresses, v6)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -147,25 +91,24 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
|||||||
// TODO: filter out down interfaces (net.FlagUp). Also handle the reverse
|
// TODO: filter out down interfaces (net.FlagUp). Also handle the reverse
|
||||||
// case where an interface comes up between refreshes.
|
// case where an interface comes up between refreshes.
|
||||||
for _, intf := range interfaces {
|
for _, intf := range interfaces {
|
||||||
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
|
processInterface(intf, ips, &addresses)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
m.mu.Lock()
|
m.snapshot.Store(&localIPSnapshot{ips: ips})
|
||||||
m.ipv4Bitmap = newIPv4Bitmap
|
|
||||||
m.mu.Unlock()
|
|
||||||
|
|
||||||
log.Debugf("Local IPv4 addresses: %v", ipv4Addresses)
|
log.Debugf("Local IP addresses: %v", addresses)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsLocalIP checks if the given IP is a local address. Lock-free on the read path.
|
||||||
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
|
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
|
||||||
if !ip.Is4() {
|
s := m.snapshot.Load()
|
||||||
return false
|
|
||||||
|
if ip.Is4() && ip.As4()[0] == 127 {
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
m.mu.RLock()
|
_, found := s.ips[ip]
|
||||||
defer m.mu.RUnlock()
|
return found
|
||||||
|
|
||||||
return m.checkBitmapBit(ip.AsSlice())
|
|
||||||
}
|
}
|
||||||
|
|||||||
72
client/firewall/uspfilter/localip_bench_test.go
Normal file
72
client/firewall/uspfilter/localip_bench_test.go
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupManager(b *testing.B) *localIPManager {
|
||||||
|
b.Helper()
|
||||||
|
m := newLocalIPManager()
|
||||||
|
mock := &IFaceMock{
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: netip.MustParseAddr("100.64.0.1"),
|
||||||
|
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
||||||
|
IPv6: netip.MustParseAddr("fd00::1"),
|
||||||
|
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err := m.UpdateLocalIPs(mock); err != nil {
|
||||||
|
b.Fatalf("UpdateLocalIPs: %v", err)
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkIsLocalIP_v4_hit(b *testing.B) {
|
||||||
|
m := setupManager(b)
|
||||||
|
ip := netip.MustParseAddr("100.64.0.1")
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
m.IsLocalIP(ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkIsLocalIP_v4_miss(b *testing.B) {
|
||||||
|
m := setupManager(b)
|
||||||
|
ip := netip.MustParseAddr("8.8.8.8")
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
m.IsLocalIP(ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkIsLocalIP_v6_hit(b *testing.B) {
|
||||||
|
m := setupManager(b)
|
||||||
|
ip := netip.MustParseAddr("fd00::1")
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
m.IsLocalIP(ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkIsLocalIP_v6_miss(b *testing.B) {
|
||||||
|
m := setupManager(b)
|
||||||
|
ip := netip.MustParseAddr("2001:db8::1")
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
m.IsLocalIP(ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkIsLocalIP_loopback(b *testing.B) {
|
||||||
|
m := setupManager(b)
|
||||||
|
ip := netip.MustParseAddr("127.0.0.1")
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
m.IsLocalIP(ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -72,14 +72,45 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "IPv6 address",
|
name: "IPv6 address matches",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: netip.MustParseAddr("fe80::1"),
|
IP: netip.MustParseAddr("100.64.0.1"),
|
||||||
|
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
||||||
|
IPv6: netip.MustParseAddr("fd00::1"),
|
||||||
|
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
||||||
|
},
|
||||||
|
testIP: netip.MustParseAddr("fd00::1"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 address does not match",
|
||||||
|
setupAddr: wgaddr.Address{
|
||||||
|
IP: netip.MustParseAddr("100.64.0.1"),
|
||||||
|
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
||||||
|
IPv6: netip.MustParseAddr("fd00::1"),
|
||||||
|
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
||||||
|
},
|
||||||
|
testIP: netip.MustParseAddr("fd00::99"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No aliasing between similar IPs",
|
||||||
|
setupAddr: wgaddr.Address{
|
||||||
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("fe80::1"),
|
testIP: netip.MustParseAddr("192.168.0.17"),
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 loopback",
|
||||||
|
setupAddr: wgaddr.Address{
|
||||||
|
IP: netip.MustParseAddr("100.64.0.1"),
|
||||||
|
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
||||||
|
},
|
||||||
|
testIP: netip.MustParseAddr("::1"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -171,90 +202,3 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MapImplementation is a version using map[string]struct{}
|
|
||||||
type MapImplementation struct {
|
|
||||||
localIPs map[string]struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkIPChecks(b *testing.B) {
|
|
||||||
interfaces := make([]net.IP, 16)
|
|
||||||
for i := range interfaces {
|
|
||||||
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setup bitmap
|
|
||||||
bitmapManager := newLocalIPManager()
|
|
||||||
for _, ip := range interfaces[:8] { // Add half of IPs
|
|
||||||
bitmapManager.setBitmapBit(ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setup map version
|
|
||||||
mapManager := &MapImplementation{
|
|
||||||
localIPs: make(map[string]struct{}),
|
|
||||||
}
|
|
||||||
for _, ip := range interfaces[:8] {
|
|
||||||
mapManager.localIPs[ip.String()] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
b.Run("Bitmap_Hit", func(b *testing.B) {
|
|
||||||
ip := interfaces[4]
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
bitmapManager.checkBitmapBit(ip)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("Bitmap_Miss", func(b *testing.B) {
|
|
||||||
ip := interfaces[12]
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
bitmapManager.checkBitmapBit(ip)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("Map_Hit", func(b *testing.B) {
|
|
||||||
ip := interfaces[4]
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
// nolint:gosimple
|
|
||||||
_ = mapManager.localIPs[ip.String()]
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("Map_Miss", func(b *testing.B) {
|
|
||||||
ip := interfaces[12]
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
// nolint:gosimple
|
|
||||||
_ = mapManager.localIPs[ip.String()]
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkWGPosition(b *testing.B) {
|
|
||||||
wgIP := net.ParseIP("10.10.0.1")
|
|
||||||
|
|
||||||
// Create two managers - one checks WG IP first, other checks it last
|
|
||||||
b.Run("WG_First", func(b *testing.B) {
|
|
||||||
bm := newLocalIPManager()
|
|
||||||
bm.setBitmapBit(wgIP)
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
bm.checkBitmapBit(wgIP)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("WG_Last", func(b *testing.B) {
|
|
||||||
bm := newLocalIPManager()
|
|
||||||
// Fill with other IPs first
|
|
||||||
for i := 0; i < 15; i++ {
|
|
||||||
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
|
|
||||||
}
|
|
||||||
bm.setBitmapBit(wgIP) // Add WG IP last
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
bm.checkBitmapBit(wgIP)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -13,8 +13,6 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errInvalidIPHeaderLength = errors.New("invalid IP header length")
|
errInvalidIPHeaderLength = errors.New("invalid IP header length")
|
||||||
)
|
)
|
||||||
@@ -25,10 +23,33 @@ const (
|
|||||||
destinationPortOffset = 2
|
destinationPortOffset = 2
|
||||||
|
|
||||||
// IP address offsets in IPv4 header
|
// IP address offsets in IPv4 header
|
||||||
sourceIPOffset = 12
|
ipv4SrcOffset = 12
|
||||||
destinationIPOffset = 16
|
ipv4DstOffset = 16
|
||||||
|
|
||||||
|
// IP address offsets in IPv6 header
|
||||||
|
ipv6SrcOffset = 8
|
||||||
|
ipv6DstOffset = 24
|
||||||
|
|
||||||
|
// IPv6 fixed header length
|
||||||
|
ipv6HeaderLen = 40
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ipHeaderLen returns the IP header length based on the decoded layer type.
|
||||||
|
func ipHeaderLen(d *decoder) (int, error) {
|
||||||
|
switch d.decoded[0] {
|
||||||
|
case layers.LayerTypeIPv4:
|
||||||
|
n := int(d.ip4.IHL) * 4
|
||||||
|
if n < 20 {
|
||||||
|
return 0, errInvalidIPHeaderLength
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
case layers.LayerTypeIPv6:
|
||||||
|
return ipv6HeaderLen, nil
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("unknown IP layer: %v", d.decoded[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ipv4Checksum calculates IPv4 header checksum.
|
// ipv4Checksum calculates IPv4 header checksum.
|
||||||
func ipv4Checksum(header []byte) uint16 {
|
func ipv4Checksum(header []byte) uint16 {
|
||||||
if len(header) < 20 {
|
if len(header) < 20 {
|
||||||
@@ -234,14 +255,13 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
_, dstIP := extractPacketIPs(packetData, d)
|
||||||
|
|
||||||
translatedIP, exists := m.getDNATTranslation(dstIP)
|
translatedIP, exists := m.getDNATTranslation(dstIP)
|
||||||
if !exists {
|
if !exists {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil {
|
if err := m.rewritePacketIP(packetData, d, translatedIP, false); err != nil {
|
||||||
m.logger.Error1("failed to rewrite packet destination: %v", err)
|
m.logger.Error1("failed to rewrite packet destination: %v", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -256,14 +276,13 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
srcIP, _ := extractPacketIPs(packetData, d)
|
||||||
|
|
||||||
originalIP, exists := m.findReverseDNATMapping(srcIP)
|
originalIP, exists := m.findReverseDNATMapping(srcIP)
|
||||||
if !exists {
|
if !exists {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil {
|
if err := m.rewritePacketIP(packetData, d, originalIP, true); err != nil {
|
||||||
m.logger.Error1("failed to rewrite packet source: %v", err)
|
m.logger.Error1("failed to rewrite packet source: %v", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -272,38 +291,96 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums.
|
// extractPacketIPs extracts src and dst IP addresses directly from raw packet bytes.
|
||||||
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error {
|
func extractPacketIPs(packetData []byte, d *decoder) (src, dst netip.Addr) {
|
||||||
|
switch d.decoded[0] {
|
||||||
|
case layers.LayerTypeIPv4:
|
||||||
|
src = netip.AddrFrom4([4]byte{packetData[ipv4SrcOffset], packetData[ipv4SrcOffset+1], packetData[ipv4SrcOffset+2], packetData[ipv4SrcOffset+3]})
|
||||||
|
dst = netip.AddrFrom4([4]byte{packetData[ipv4DstOffset], packetData[ipv4DstOffset+1], packetData[ipv4DstOffset+2], packetData[ipv4DstOffset+3]})
|
||||||
|
case layers.LayerTypeIPv6:
|
||||||
|
src = netip.AddrFrom16([16]byte(packetData[ipv6SrcOffset : ipv6SrcOffset+16]))
|
||||||
|
dst = netip.AddrFrom16([16]byte(packetData[ipv6DstOffset : ipv6DstOffset+16]))
|
||||||
|
}
|
||||||
|
return src, dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewritePacketIP replaces a source (isSource=true) or destination IP address in the packet and updates checksums.
|
||||||
|
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, isSource bool) error {
|
||||||
|
hdrLen, err := ipHeaderLen(d)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch d.decoded[0] {
|
||||||
|
case layers.LayerTypeIPv4:
|
||||||
|
return m.rewriteIPv4(packetData, d, newIP, hdrLen, isSource)
|
||||||
|
case layers.LayerTypeIPv6:
|
||||||
|
return m.rewriteIPv6(packetData, d, newIP, hdrLen, isSource)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown IP layer: %v", d.decoded[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) rewriteIPv4(packetData []byte, d *decoder, newIP netip.Addr, hdrLen int, isSource bool) error {
|
||||||
if !newIP.Is4() {
|
if !newIP.Is4() {
|
||||||
return ErrIPv4Only
|
return fmt.Errorf("cannot write IPv6 address into IPv4 packet")
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := ipv4DstOffset
|
||||||
|
if isSource {
|
||||||
|
offset = ipv4SrcOffset
|
||||||
}
|
}
|
||||||
|
|
||||||
var oldIP [4]byte
|
var oldIP [4]byte
|
||||||
copy(oldIP[:], packetData[ipOffset:ipOffset+4])
|
copy(oldIP[:], packetData[offset:offset+4])
|
||||||
newIPBytes := newIP.As4()
|
newIPBytes := newIP.As4()
|
||||||
|
copy(packetData[offset:offset+4], newIPBytes[:])
|
||||||
|
|
||||||
copy(packetData[ipOffset:ipOffset+4], newIPBytes[:])
|
// Recalculate IPv4 header checksum
|
||||||
|
|
||||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
|
||||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
|
||||||
return errInvalidIPHeaderLength
|
|
||||||
}
|
|
||||||
|
|
||||||
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||||
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
|
binary.BigEndian.PutUint16(packetData[10:12], ipv4Checksum(packetData[:hdrLen]))
|
||||||
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
|
|
||||||
|
|
||||||
|
// Update transport checksums incrementally
|
||||||
if len(d.decoded) > 1 {
|
if len(d.decoded) > 1 {
|
||||||
switch d.decoded[1] {
|
switch d.decoded[1] {
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
|
m.updateTCPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
|
||||||
case layers.LayerTypeUDP:
|
case layers.LayerTypeUDP:
|
||||||
m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
|
m.updateUDPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
|
||||||
case layers.LayerTypeICMPv4:
|
case layers.LayerTypeICMPv4:
|
||||||
m.updateICMPChecksum(packetData, ipHeaderLen)
|
m.updateICMPChecksum(packetData, hdrLen)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) rewriteIPv6(packetData []byte, d *decoder, newIP netip.Addr, hdrLen int, isSource bool) error {
|
||||||
|
if !newIP.Is6() {
|
||||||
|
return fmt.Errorf("cannot write IPv4 address into IPv6 packet")
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := ipv6DstOffset
|
||||||
|
if isSource {
|
||||||
|
offset = ipv6SrcOffset
|
||||||
|
}
|
||||||
|
|
||||||
|
var oldIP [16]byte
|
||||||
|
copy(oldIP[:], packetData[offset:offset+16])
|
||||||
|
newIPBytes := newIP.As16()
|
||||||
|
copy(packetData[offset:offset+16], newIPBytes[:])
|
||||||
|
|
||||||
|
// IPv6 has no header checksum, only update transport checksums
|
||||||
|
if len(d.decoded) > 1 {
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
m.updateTCPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
m.updateUDPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
|
||||||
|
case layers.LayerTypeICMPv6:
|
||||||
|
// ICMPv6 checksum includes pseudo-header with addresses, use incremental update
|
||||||
|
m.updateICMPv6Checksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -351,6 +428,20 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
|
|||||||
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
|
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// updateICMPv6Checksum updates ICMPv6 checksum after address change.
|
||||||
|
// ICMPv6 uses a pseudo-header (like TCP/UDP), so incremental update applies.
|
||||||
|
func (m *Manager) updateICMPv6Checksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||||
|
icmpStart := ipHeaderLen
|
||||||
|
if len(packetData) < icmpStart+4 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
checksumOffset := icmpStart + 2
|
||||||
|
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||||
|
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
|
||||||
|
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||||
|
}
|
||||||
|
|
||||||
// incrementalUpdate performs incremental checksum update per RFC 1624.
|
// incrementalUpdate performs incremental checksum update per RFC 1624.
|
||||||
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
||||||
sum := uint32(^oldChecksum)
|
sum := uint32(^oldChecksum)
|
||||||
@@ -532,12 +623,12 @@ func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP neti
|
|||||||
|
|
||||||
// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum.
|
// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum.
|
||||||
func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
|
func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
|
||||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
hdrLen, err := ipHeaderLen(d)
|
||||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
if err != nil {
|
||||||
return errInvalidIPHeaderLength
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
tcpStart := ipHeaderLen
|
tcpStart := hdrLen
|
||||||
if len(packetData) < tcpStart+4 {
|
if len(packetData) < tcpStart+4 {
|
||||||
return fmt.Errorf("packet too short for TCP header")
|
return fmt.Errorf("packet too short for TCP header")
|
||||||
}
|
}
|
||||||
@@ -563,12 +654,12 @@ func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16,
|
|||||||
|
|
||||||
// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum.
|
// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum.
|
||||||
func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
|
func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
|
||||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
hdrLen, err := ipHeaderLen(d)
|
||||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
if err != nil {
|
||||||
return errInvalidIPHeaderLength
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
udpStart := ipHeaderLen
|
udpStart := hdrLen
|
||||||
if len(packetData) < udpStart+8 {
|
if len(packetData) < udpStart+8 {
|
||||||
return fmt.Errorf("packet too short for UDP header")
|
return fmt.Errorf("packet too short for UDP header")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -342,12 +342,17 @@ func BenchmarkDNATMemoryAllocations(b *testing.B) {
|
|||||||
|
|
||||||
// Parse the packet fresh each time to get a clean decoder
|
// Parse the packet fresh each time to get a clean decoder
|
||||||
d := &decoder{decoded: []gopacket.LayerType{}}
|
d := &decoder{decoded: []gopacket.LayerType{}}
|
||||||
d.parser = gopacket.NewDecodingLayerParser(
|
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||||
layers.LayerTypeIPv4,
|
layers.LayerTypeIPv4,
|
||||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
)
|
)
|
||||||
d.parser.IgnoreUnsupported = true
|
d.parser4.IgnoreUnsupported = true
|
||||||
err = d.parser.DecodeLayers(testPacket, &d.decoded)
|
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||||
|
layers.LayerTypeIPv6,
|
||||||
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
|
)
|
||||||
|
d.parser6.IgnoreUnsupported = true
|
||||||
|
err = d.decodePacket(testPacket)
|
||||||
assert.NoError(b, err)
|
assert.NoError(b, err)
|
||||||
|
|
||||||
manager.translateOutboundDNAT(testPacket, d)
|
manager.translateOutboundDNAT(testPacket, d)
|
||||||
@@ -371,12 +376,17 @@ func BenchmarkDirectIPExtraction(b *testing.B) {
|
|||||||
b.Run("decoder_extraction", func(b *testing.B) {
|
b.Run("decoder_extraction", func(b *testing.B) {
|
||||||
// Create decoder once for comparison
|
// Create decoder once for comparison
|
||||||
d := &decoder{decoded: []gopacket.LayerType{}}
|
d := &decoder{decoded: []gopacket.LayerType{}}
|
||||||
d.parser = gopacket.NewDecodingLayerParser(
|
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||||
layers.LayerTypeIPv4,
|
layers.LayerTypeIPv4,
|
||||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
)
|
)
|
||||||
d.parser.IgnoreUnsupported = true
|
d.parser4.IgnoreUnsupported = true
|
||||||
err := d.parser.DecodeLayers(packet, &d.decoded)
|
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||||
|
layers.LayerTypeIPv6,
|
||||||
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
|
)
|
||||||
|
d.parser6.IgnoreUnsupported = true
|
||||||
|
err := d.decodePacket(packet)
|
||||||
assert.NoError(b, err)
|
assert.NoError(b, err)
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
|
|||||||
@@ -86,13 +86,18 @@ func parsePacket(t testing.TB, packetData []byte) *decoder {
|
|||||||
d := &decoder{
|
d := &decoder{
|
||||||
decoded: []gopacket.LayerType{},
|
decoded: []gopacket.LayerType{},
|
||||||
}
|
}
|
||||||
d.parser = gopacket.NewDecodingLayerParser(
|
d.parser4 = gopacket.NewDecodingLayerParser(
|
||||||
layers.LayerTypeIPv4,
|
layers.LayerTypeIPv4,
|
||||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
)
|
)
|
||||||
d.parser.IgnoreUnsupported = true
|
d.parser4.IgnoreUnsupported = true
|
||||||
|
d.parser6 = gopacket.NewDecodingLayerParser(
|
||||||
|
layers.LayerTypeIPv6,
|
||||||
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
|
)
|
||||||
|
d.parser6.IgnoreUnsupported = true
|
||||||
|
|
||||||
err := d.parser.DecodeLayers(packetData, &d.decoded)
|
err := d.decodePacket(packetData)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -112,10 +112,13 @@ func (t *PacketTrace) AddResultWithForwarder(stage PacketStage, message string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *PacketBuilder) Build() ([]byte, error) {
|
func (p *PacketBuilder) Build() ([]byte, error) {
|
||||||
ip := p.buildIPLayer()
|
ipLayer, err := p.buildIPLayer()
|
||||||
pktLayers := []gopacket.SerializableLayer{ip}
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
pktLayers := []gopacket.SerializableLayer{ipLayer}
|
||||||
|
|
||||||
transportLayer, err := p.buildTransportLayer(ip)
|
transportLayer, err := p.buildTransportLayer(ipLayer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -129,30 +132,43 @@ func (p *PacketBuilder) Build() ([]byte, error) {
|
|||||||
return serializePacket(pktLayers)
|
return serializePacket(pktLayers)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
|
func (p *PacketBuilder) buildIPLayer() (gopacket.SerializableLayer, error) {
|
||||||
|
if p.SrcIP.Is4() != p.DstIP.Is4() {
|
||||||
|
return nil, fmt.Errorf("mixed address families: src=%s dst=%s", p.SrcIP, p.DstIP)
|
||||||
|
}
|
||||||
|
proto := getIPProtocolNumber(p.Protocol, p.SrcIP.Is6())
|
||||||
|
if p.SrcIP.Is6() {
|
||||||
|
return &layers.IPv6{
|
||||||
|
Version: 6,
|
||||||
|
HopLimit: 64,
|
||||||
|
NextHeader: proto,
|
||||||
|
SrcIP: p.SrcIP.AsSlice(),
|
||||||
|
DstIP: p.DstIP.AsSlice(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
return &layers.IPv4{
|
return &layers.IPv4{
|
||||||
Version: 4,
|
Version: 4,
|
||||||
TTL: 64,
|
TTL: 64,
|
||||||
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
Protocol: proto,
|
||||||
SrcIP: p.SrcIP.AsSlice(),
|
SrcIP: p.SrcIP.AsSlice(),
|
||||||
DstIP: p.DstIP.AsSlice(),
|
DstIP: p.DstIP.AsSlice(),
|
||||||
}
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PacketBuilder) buildTransportLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
func (p *PacketBuilder) buildTransportLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
|
||||||
switch p.Protocol {
|
switch p.Protocol {
|
||||||
case "tcp":
|
case "tcp":
|
||||||
return p.buildTCPLayer(ip)
|
return p.buildTCPLayer(ipLayer)
|
||||||
case "udp":
|
case "udp":
|
||||||
return p.buildUDPLayer(ip)
|
return p.buildUDPLayer(ipLayer)
|
||||||
case "icmp":
|
case "icmp":
|
||||||
return p.buildICMPLayer()
|
return p.buildICMPLayer(ipLayer)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol)
|
return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
func (p *PacketBuilder) buildTCPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
|
||||||
tcp := &layers.TCP{
|
tcp := &layers.TCP{
|
||||||
SrcPort: layers.TCPPort(p.SrcPort),
|
SrcPort: layers.TCPPort(p.SrcPort),
|
||||||
DstPort: layers.TCPPort(p.DstPort),
|
DstPort: layers.TCPPort(p.DstPort),
|
||||||
@@ -164,24 +180,44 @@ func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableL
|
|||||||
PSH: p.TCPState != nil && p.TCPState.PSH,
|
PSH: p.TCPState != nil && p.TCPState.PSH,
|
||||||
URG: p.TCPState != nil && p.TCPState.URG,
|
URG: p.TCPState != nil && p.TCPState.URG,
|
||||||
}
|
}
|
||||||
if err := tcp.SetNetworkLayerForChecksum(ip); err != nil {
|
if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
|
||||||
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
|
if err := tcp.SetNetworkLayerForChecksum(nl); err != nil {
|
||||||
|
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return []gopacket.SerializableLayer{tcp}, nil
|
return []gopacket.SerializableLayer{tcp}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PacketBuilder) buildUDPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
func (p *PacketBuilder) buildUDPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
|
||||||
udp := &layers.UDP{
|
udp := &layers.UDP{
|
||||||
SrcPort: layers.UDPPort(p.SrcPort),
|
SrcPort: layers.UDPPort(p.SrcPort),
|
||||||
DstPort: layers.UDPPort(p.DstPort),
|
DstPort: layers.UDPPort(p.DstPort),
|
||||||
}
|
}
|
||||||
if err := udp.SetNetworkLayerForChecksum(ip); err != nil {
|
if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
|
||||||
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
|
if err := udp.SetNetworkLayerForChecksum(nl); err != nil {
|
||||||
|
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return []gopacket.SerializableLayer{udp}, nil
|
return []gopacket.SerializableLayer{udp}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) {
|
func (p *PacketBuilder) buildICMPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
|
||||||
|
if p.SrcIP.Is6() || p.DstIP.Is6() {
|
||||||
|
icmp := &layers.ICMPv6{
|
||||||
|
TypeCode: layers.CreateICMPv6TypeCode(p.ICMPType, p.ICMPCode),
|
||||||
|
}
|
||||||
|
if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
|
||||||
|
_ = icmp.SetNetworkLayerForChecksum(nl)
|
||||||
|
}
|
||||||
|
if p.ICMPType == layers.ICMPv6TypeEchoRequest || p.ICMPType == layers.ICMPv6TypeEchoReply {
|
||||||
|
echo := &layers.ICMPv6Echo{
|
||||||
|
Identifier: 1,
|
||||||
|
SeqNumber: 1,
|
||||||
|
}
|
||||||
|
return []gopacket.SerializableLayer{icmp, echo}, nil
|
||||||
|
}
|
||||||
|
return []gopacket.SerializableLayer{icmp}, nil
|
||||||
|
}
|
||||||
icmp := &layers.ICMPv4{
|
icmp := &layers.ICMPv4{
|
||||||
TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode),
|
TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode),
|
||||||
}
|
}
|
||||||
@@ -204,14 +240,17 @@ func serializePacket(layers []gopacket.SerializableLayer) ([]byte, error) {
|
|||||||
return buf.Bytes(), nil
|
return buf.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getIPProtocolNumber(protocol fw.Protocol) int {
|
func getIPProtocolNumber(protocol fw.Protocol, isV6 bool) layers.IPProtocol {
|
||||||
switch protocol {
|
switch protocol {
|
||||||
case fw.ProtocolTCP:
|
case fw.ProtocolTCP:
|
||||||
return int(layers.IPProtocolTCP)
|
return layers.IPProtocolTCP
|
||||||
case fw.ProtocolUDP:
|
case fw.ProtocolUDP:
|
||||||
return int(layers.IPProtocolUDP)
|
return layers.IPProtocolUDP
|
||||||
case fw.ProtocolICMP:
|
case fw.ProtocolICMP:
|
||||||
return int(layers.IPProtocolICMPv4)
|
if isV6 {
|
||||||
|
return layers.IPProtocolICMPv6
|
||||||
|
}
|
||||||
|
return layers.IPProtocolICMPv4
|
||||||
default:
|
default:
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
@@ -234,7 +273,7 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
|||||||
trace := &PacketTrace{Direction: direction}
|
trace := &PacketTrace{Direction: direction}
|
||||||
|
|
||||||
// Initial packet decoding
|
// Initial packet decoding
|
||||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
if err := d.decodePacket(packetData); err != nil {
|
||||||
trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false)
|
trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false)
|
||||||
return trace
|
return trace
|
||||||
}
|
}
|
||||||
@@ -256,6 +295,8 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
|||||||
trace.DestinationPort = uint16(d.udp.DstPort)
|
trace.DestinationPort = uint16(d.udp.DstPort)
|
||||||
case layers.LayerTypeICMPv4:
|
case layers.LayerTypeICMPv4:
|
||||||
trace.Protocol = "ICMP"
|
trace.Protocol = "ICMP"
|
||||||
|
case layers.LayerTypeICMPv6:
|
||||||
|
trace.Protocol = "ICMPv6"
|
||||||
}
|
}
|
||||||
|
|
||||||
trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d",
|
trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d",
|
||||||
@@ -319,6 +360,13 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
|
|||||||
flags&conntrack.TCPFin != 0)
|
flags&conntrack.TCPFin != 0)
|
||||||
case layers.LayerTypeICMPv4:
|
case layers.LayerTypeICMPv4:
|
||||||
msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, d.icmp4.Seq)
|
msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, d.icmp4.Seq)
|
||||||
|
case layers.LayerTypeICMPv6:
|
||||||
|
var id, seq uint16
|
||||||
|
if len(d.icmp6.Payload) >= 4 {
|
||||||
|
id = uint16(d.icmp6.Payload[0])<<8 | uint16(d.icmp6.Payload[1])
|
||||||
|
seq = uint16(d.icmp6.Payload[2])<<8 | uint16(d.icmp6.Payload[3])
|
||||||
|
}
|
||||||
|
msg += fmt.Sprintf(" (ICMPv6 ID=%d, Seq=%d)", id, seq)
|
||||||
}
|
}
|
||||||
return msg
|
return msg
|
||||||
}
|
}
|
||||||
@@ -415,7 +463,7 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
|
|||||||
d := m.decoders.Get().(*decoder)
|
d := m.decoders.Get().(*decoder)
|
||||||
defer m.decoders.Put(d)
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
if err := d.decodePacket(packetData); err != nil {
|
||||||
trace.AddResult(StageCompleted, "Packet dropped - decode error", false)
|
trace.AddResult(StageCompleted, "Packet dropped - decode error", false)
|
||||||
return trace
|
return trace
|
||||||
}
|
}
|
||||||
@@ -434,7 +482,7 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
|
|||||||
func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool {
|
func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool {
|
||||||
portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d)
|
portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d)
|
||||||
if portDNATApplied {
|
if portDNATApplied {
|
||||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
if err := d.decodePacket(packetData); err != nil {
|
||||||
trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false)
|
trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -444,7 +492,7 @@ func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *de
|
|||||||
|
|
||||||
nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d)
|
nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d)
|
||||||
if nat1to1Applied {
|
if nat1to1Applied {
|
||||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
if err := d.decodePacket(packetData); err != nil {
|
||||||
trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false)
|
trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -509,7 +557,7 @@ func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
srcIP, _ := extractPacketIPs(packetData, d)
|
||||||
|
|
||||||
translated := m.translateInboundReverse(packetData, d)
|
translated := m.translateInboundReverse(packetData, d)
|
||||||
if translated {
|
if translated {
|
||||||
@@ -539,7 +587,7 @@ func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
_, dstIP := extractPacketIPs(packetData, d)
|
||||||
|
|
||||||
translated := m.translateOutboundDNAT(packetData, d)
|
translated := m.translateOutboundDNAT(packetData, d)
|
||||||
if translated {
|
if translated {
|
||||||
|
|||||||
@@ -119,7 +119,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to parse endpoint address: %w", err)
|
return fmt.Errorf("failed to parse endpoint address: %w", err)
|
||||||
}
|
}
|
||||||
addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port))
|
addrPort := netip.AddrPortFrom(addr.Unmap(), uint16(endpoint.Port))
|
||||||
c.activityRecorder.UpsertAddress(peerKey, addrPort)
|
c.activityRecorder.UpsertAddress(peerKey, addrPort)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package device
|
|||||||
|
|
||||||
// TunAdapter is an interface for create tun device from external service
|
// TunAdapter is an interface for create tun device from external service
|
||||||
type TunAdapter interface {
|
type TunAdapter interface {
|
||||||
ConfigureInterface(address string, mtu int, dns string, searchDomains string, routes string) (int, error)
|
ConfigureInterface(address string, addressV6 string, mtu int, dns string, searchDomains string, routes string) (int, error)
|
||||||
UpdateAddr(address string) error
|
UpdateAddr(address string) error
|
||||||
ProtectSocket(fd int32) bool
|
ProtectSocket(fd int32) bool
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
|||||||
searchDomainsToString = ""
|
searchDomainsToString = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), int(t.mtu), dns, searchDomainsToString, routesString)
|
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.address.IPv6String(), int(t.mtu), dns, searchDomainsToString, routesString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to create Android interface: %s", err)
|
log.Errorf("failed to create Android interface: %s", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -196,18 +196,22 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// fakeAddress returns a fake address that is used to as an identifier for the peer.
|
// fakeAddress returns a fake address that is used as an identifier for the peer.
|
||||||
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
|
// The fake address is in the format of 127.1.x.x where x.x is derived from the
|
||||||
|
// last two bytes of the peer address (works for both IPv4 and IPv6).
|
||||||
func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
|
func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
|
||||||
octets := strings.Split(peerAddress.IP.String(), ".")
|
if peerAddress == nil {
|
||||||
if len(octets) != 4 {
|
return nil, fmt.Errorf("nil peer address")
|
||||||
return nil, fmt.Errorf("invalid IP format")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3]))
|
addr, ok := netip.AddrFromSlice(peerAddress.IP)
|
||||||
if err != nil {
|
if !ok {
|
||||||
return nil, fmt.Errorf("parse new IP: %w", err)
|
return nil, fmt.Errorf("invalid IP format")
|
||||||
}
|
}
|
||||||
|
addr = addr.Unmap()
|
||||||
|
|
||||||
|
raw := addr.As16()
|
||||||
|
fakeIP := netip.AddrFrom4([4]byte{127, 1, raw[14], raw[15]})
|
||||||
|
|
||||||
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
|
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
|
||||||
return &netipAddr, nil
|
return &netipAddr, nil
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -19,6 +18,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
"github.com/netbirdio/netbird/shared/netiputil"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrSourceRangesEmpty = errors.New("sources range is empty")
|
var ErrSourceRangesEmpty = errors.New("sources range is empty")
|
||||||
@@ -105,6 +105,10 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
|||||||
newRulePairs := make(map[id.RuleID][]firewall.Rule)
|
newRulePairs := make(map[id.RuleID][]firewall.Rule)
|
||||||
ipsetByRuleSelectors := make(map[string]string)
|
ipsetByRuleSelectors := make(map[string]string)
|
||||||
|
|
||||||
|
// TODO: deny rules should be fatal: if a deny rule fails to apply, we must
|
||||||
|
// roll back all allow rules to avoid a fail-open where allowed traffic bypasses
|
||||||
|
// the missing deny. Currently we accumulate errors and continue.
|
||||||
|
var merr *multierror.Error
|
||||||
for _, r := range rules {
|
for _, r := range rules {
|
||||||
// if this rule is member of rule selection with more than DefaultIPsCountForSet
|
// if this rule is member of rule selection with more than DefaultIPsCountForSet
|
||||||
// it's IP address can be used in the ipset for firewall manager which supports it
|
// it's IP address can be used in the ipset for firewall manager which supports it
|
||||||
@@ -117,9 +121,8 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
|||||||
}
|
}
|
||||||
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
|
merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err))
|
||||||
d.rollBack(newRulePairs)
|
continue
|
||||||
break
|
|
||||||
}
|
}
|
||||||
if len(rulePair) > 0 {
|
if len(rulePair) > 0 {
|
||||||
d.peerRulesPairs[pairID] = rulePair
|
d.peerRulesPairs[pairID] = rulePair
|
||||||
@@ -127,6 +130,10 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if merr != nil {
|
||||||
|
log.Errorf("failed to apply %d peer ACL rule(s): %v", merr.Len(), nberrors.FormatErrorOrNil(merr))
|
||||||
|
}
|
||||||
|
|
||||||
for pairID, rules := range d.peerRulesPairs {
|
for pairID, rules := range d.peerRulesPairs {
|
||||||
if _, ok := newRulePairs[pairID]; !ok {
|
if _, ok := newRulePairs[pairID]; !ok {
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
@@ -216,10 +223,9 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
|||||||
r *mgmProto.FirewallRule,
|
r *mgmProto.FirewallRule,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
) (id.RuleID, []firewall.Rule, error) {
|
) (id.RuleID, []firewall.Rule, error) {
|
||||||
//nolint:staticcheck // PeerIP used for backward compatibility with old management
|
ip, err := extractRuleIP(r)
|
||||||
ip := net.ParseIP(r.PeerIP)
|
if err != nil {
|
||||||
if ip == nil {
|
return "", nil, err
|
||||||
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protocol, err := convertToFirewallProtocol(r.Protocol)
|
protocol, err := convertToFirewallProtocol(r.Protocol)
|
||||||
@@ -290,13 +296,13 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
|||||||
|
|
||||||
func (d *DefaultManager) addInRules(
|
func (d *DefaultManager) addInRules(
|
||||||
id []byte,
|
id []byte,
|
||||||
ip net.IP,
|
ip netip.Addr,
|
||||||
protocol firewall.Protocol,
|
protocol firewall.Protocol,
|
||||||
port *firewall.Port,
|
port *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, nil, port, action, ipsetName)
|
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, nil, port, action, ipsetName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||||
}
|
}
|
||||||
@@ -306,7 +312,7 @@ func (d *DefaultManager) addInRules(
|
|||||||
|
|
||||||
func (d *DefaultManager) addOutRules(
|
func (d *DefaultManager) addOutRules(
|
||||||
id []byte,
|
id []byte,
|
||||||
ip net.IP,
|
ip netip.Addr,
|
||||||
protocol firewall.Protocol,
|
protocol firewall.Protocol,
|
||||||
port *firewall.Port,
|
port *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
@@ -316,7 +322,7 @@ func (d *DefaultManager) addOutRules(
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, port, nil, action, ipsetName)
|
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, port, nil, action, ipsetName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||||
}
|
}
|
||||||
@@ -324,9 +330,9 @@ func (d *DefaultManager) addOutRules(
|
|||||||
return rule, nil
|
return rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getPeerRuleID() returns unique ID for the rule based on its parameters.
|
// getPeerRuleID returns unique ID for the rule based on its parameters.
|
||||||
func (d *DefaultManager) getPeerRuleID(
|
func (d *DefaultManager) getPeerRuleID(
|
||||||
ip net.IP,
|
ip netip.Addr,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
direction int,
|
direction int,
|
||||||
port *firewall.Port,
|
port *firewall.Port,
|
||||||
@@ -345,15 +351,25 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st
|
|||||||
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
|
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
|
|
||||||
log.Debugf("rollback ACL to previous state")
|
// extractRuleIP extracts the peer IP from a firewall rule.
|
||||||
for _, rules := range newRulePairs {
|
// If sourcePrefixes is populated (new management), decode the first entry and use its address.
|
||||||
for _, rule := range rules {
|
// Otherwise fall back to the deprecated PeerIP string field (old management).
|
||||||
if err := d.firewall.DeletePeerRule(rule); err != nil {
|
func extractRuleIP(r *mgmProto.FirewallRule) (netip.Addr, error) {
|
||||||
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.ID(), err)
|
if len(r.SourcePrefixes) > 0 {
|
||||||
}
|
addr, err := netiputil.DecodeAddr(r.SourcePrefixes[0])
|
||||||
|
if err != nil {
|
||||||
|
return netip.Addr{}, fmt.Errorf("decode source prefix: %w", err)
|
||||||
}
|
}
|
||||||
|
return addr.Unmap(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint:staticcheck // PeerIP used for backward compatibility with old management
|
||||||
|
addr, err := netip.ParseAddr(r.PeerIP)
|
||||||
|
if err != nil {
|
||||||
|
return netip.Addr{}, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||||
|
}
|
||||||
|
return addr.Unmap(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
|
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
|
||||||
|
|||||||
@@ -430,8 +430,6 @@ func isInCGNATRange(ip net.IP) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAnonymizeFirewallRules(t *testing.T) {
|
func TestAnonymizeFirewallRules(t *testing.T) {
|
||||||
// TODO: Add ipv6
|
|
||||||
|
|
||||||
// Example iptables-save output
|
// Example iptables-save output
|
||||||
iptablesSave := `# Generated by iptables-save v1.8.7 on Thu Dec 19 10:00:00 2024
|
iptablesSave := `# Generated by iptables-save v1.8.7 on Thu Dec 19 10:00:00 2024
|
||||||
*filter
|
*filter
|
||||||
@@ -467,17 +465,31 @@ Chain FORWARD (policy ACCEPT 0 packets, 0 bytes)
|
|||||||
Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
|
Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
|
||||||
pkts bytes target prot opt in out source destination`
|
pkts bytes target prot opt in out source destination`
|
||||||
|
|
||||||
// Example nftables output
|
// Example ip6tables-save output
|
||||||
|
ip6tablesSave := `# Generated by ip6tables-save v1.8.7 on Thu Dec 19 10:00:00 2024
|
||||||
|
*filter
|
||||||
|
:INPUT ACCEPT [0:0]
|
||||||
|
:FORWARD ACCEPT [0:0]
|
||||||
|
:OUTPUT ACCEPT [0:0]
|
||||||
|
-A INPUT -s fd00:1234::1/128 -j ACCEPT
|
||||||
|
-A INPUT -s 2607:f8b0:4005::1/128 -j DROP
|
||||||
|
-A FORWARD -s 2001:db8::/32 -d 2607:f8b0:4005::200e/128 -j ACCEPT
|
||||||
|
COMMIT`
|
||||||
|
|
||||||
|
// Example nftables output with IPv6
|
||||||
nftablesRules := `table inet filter {
|
nftablesRules := `table inet filter {
|
||||||
chain input {
|
chain input {
|
||||||
type filter hook input priority filter; policy accept;
|
type filter hook input priority filter; policy accept;
|
||||||
ip saddr 192.168.1.1 accept
|
ip saddr 192.168.1.1 accept
|
||||||
ip saddr 44.192.140.1 drop
|
ip saddr 44.192.140.1 drop
|
||||||
|
ip6 saddr 2607:f8b0:4005::1 drop
|
||||||
|
ip6 saddr fd00:1234::1 accept
|
||||||
}
|
}
|
||||||
chain forward {
|
chain forward {
|
||||||
type filter hook forward priority filter; policy accept;
|
type filter hook forward priority filter; policy accept;
|
||||||
ip saddr 10.0.0.0/8 drop
|
ip saddr 10.0.0.0/8 drop
|
||||||
ip saddr 44.192.140.0/24 ip daddr 52.84.12.34/24 accept
|
ip saddr 44.192.140.0/24 ip daddr 52.84.12.34/24 accept
|
||||||
|
ip6 saddr 2001:db8::/32 ip6 daddr 2607:f8b0:4005::200e/128 accept
|
||||||
}
|
}
|
||||||
}`
|
}`
|
||||||
|
|
||||||
@@ -540,4 +552,35 @@ Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
|
|||||||
assert.Contains(t, anonNftables, "table inet filter {")
|
assert.Contains(t, anonNftables, "table inet filter {")
|
||||||
assert.Contains(t, anonNftables, "chain input {")
|
assert.Contains(t, anonNftables, "chain input {")
|
||||||
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
|
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
|
||||||
|
|
||||||
|
// IPv6 public addresses in nftables should be anonymized
|
||||||
|
assert.NotContains(t, anonNftables, "2607:f8b0:4005::1")
|
||||||
|
assert.NotContains(t, anonNftables, "2607:f8b0:4005::200e")
|
||||||
|
assert.NotContains(t, anonNftables, "2001:db8::")
|
||||||
|
assert.Contains(t, anonNftables, "2001:db8:ffff::") // Default anonymous v6 range
|
||||||
|
|
||||||
|
// ULA addresses in nftables should remain unchanged (private)
|
||||||
|
assert.Contains(t, anonNftables, "fd00:1234::1")
|
||||||
|
|
||||||
|
// IPv6 nftables structure preserved
|
||||||
|
assert.Contains(t, anonNftables, "ip6 saddr")
|
||||||
|
assert.Contains(t, anonNftables, "ip6 daddr")
|
||||||
|
|
||||||
|
// Test ip6tables-save anonymization
|
||||||
|
anonIp6tablesSave := anonymizer.AnonymizeString(ip6tablesSave)
|
||||||
|
|
||||||
|
// ULA (private) IPv6 should remain unchanged
|
||||||
|
assert.Contains(t, anonIp6tablesSave, "fd00:1234::1/128")
|
||||||
|
|
||||||
|
// Public IPv6 addresses should be anonymized
|
||||||
|
assert.NotContains(t, anonIp6tablesSave, "2607:f8b0:4005::1")
|
||||||
|
assert.NotContains(t, anonIp6tablesSave, "2607:f8b0:4005::200e")
|
||||||
|
assert.NotContains(t, anonIp6tablesSave, "2001:db8::")
|
||||||
|
assert.Contains(t, anonIp6tablesSave, "2001:db8:ffff::") // Default anonymous v6 range
|
||||||
|
|
||||||
|
// Structure should be preserved
|
||||||
|
assert.Contains(t, anonIp6tablesSave, "*filter")
|
||||||
|
assert.Contains(t, anonIp6tablesSave, "COMMIT")
|
||||||
|
assert.Contains(t, anonIp6tablesSave, "-j DROP")
|
||||||
|
assert.Contains(t, anonIp6tablesSave, "-j ACCEPT")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -189,10 +189,10 @@ func (s *serviceViaListener) RuntimeIP() netip.Addr {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// evalListenAddress figure out the listen address for the DNS server
|
// evalListenAddress figures out the listen address for the DNS server.
|
||||||
// first check the 53 port availability on WG interface or lo, if not success
|
// IPv4-only: all peers have a v4 overlay address, and DNS config points to v4.
|
||||||
// pick a random port on WG interface for eBPF, if not success
|
// First checks port 53 on WG interface or lo, then tries eBPF on a random port,
|
||||||
// check the 5053 port availability on WG interface or lo without eBPF usage,
|
// then falls back to port 5053.
|
||||||
func (s *serviceViaListener) evalListenAddress() (netip.Addr, uint16, error) {
|
func (s *serviceViaListener) evalListenAddress() (netip.Addr, uint16, error) {
|
||||||
if s.customAddr != nil {
|
if s.customAddr != nil {
|
||||||
return s.customAddr.Addr(), s.customAddr.Port(), nil
|
return s.customAddr.Addr(), s.customAddr.Port(), nil
|
||||||
@@ -278,7 +278,7 @@ func (s *serviceViaListener) tryToUseeBPF() (ebpfMgr.Manager, uint16, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ebpfSrv := ebpf.GetEbpfManagerInstance()
|
ebpfSrv := ebpf.GetEbpfManagerInstance()
|
||||||
err = ebpfSrv.LoadDNSFwd(s.wgInterface.Address().IP.String(), int(port))
|
err = ebpfSrv.LoadDNSFwd(s.wgInterface.Address().IP, int(port))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to load DNS forwarder eBPF program, error: %s", err)
|
log.Warnf("failed to load DNS forwarder eBPF program, error: %s", err)
|
||||||
return nil, 0, false
|
return nil, 0, false
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
@@ -29,6 +30,12 @@ import (
|
|||||||
|
|
||||||
var currentMTU uint16 = iface.DefaultMTU
|
var currentMTU uint16 = iface.DefaultMTU
|
||||||
|
|
||||||
|
// privateClientIface is the subset of the WireGuard interface needed by GetClientPrivate.
|
||||||
|
type privateClientIface interface {
|
||||||
|
Name() string
|
||||||
|
Address() wgaddr.Address
|
||||||
|
}
|
||||||
|
|
||||||
func SetCurrentMTU(mtu uint16) {
|
func SetCurrentMTU(mtu uint16) {
|
||||||
currentMTU = mtu
|
currentMTU = mtu
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ func (u *upstreamResolver) isLocalResolver(upstream string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
func GetClientPrivate(_ privateClientIface, _ netip.Addr, dialTimeout time.Duration) (*dns.Client, error) {
|
||||||
return &dns.Client{
|
return &dns.Client{
|
||||||
Timeout: dialTimeout,
|
Timeout: dialTimeout,
|
||||||
Net: "udp",
|
Net: "udp",
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns
|
|||||||
return ExchangeWithFallback(ctx, client, r, upstream)
|
return ExchangeWithFallback(ctx, client, r, upstream)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
func GetClientPrivate(_ privateClientIface, _ netip.Addr, dialTimeout time.Duration) (*dns.Client, error) {
|
||||||
return &dns.Client{
|
return &dns.Client{
|
||||||
Timeout: dialTimeout,
|
Timeout: dialTimeout,
|
||||||
Net: "udp",
|
Net: "udp",
|
||||||
|
|||||||
@@ -19,11 +19,7 @@ import (
|
|||||||
|
|
||||||
type upstreamResolverIOS struct {
|
type upstreamResolverIOS struct {
|
||||||
*upstreamResolverBase
|
*upstreamResolverBase
|
||||||
lIP netip.Addr
|
wgIface WGIface
|
||||||
lNet netip.Prefix
|
|
||||||
lIPv6 netip.Addr
|
|
||||||
lNetV6 netip.Prefix
|
|
||||||
interfaceName string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUpstreamResolver(
|
func newUpstreamResolver(
|
||||||
@@ -37,11 +33,7 @@ func newUpstreamResolver(
|
|||||||
|
|
||||||
ios := &upstreamResolverIOS{
|
ios := &upstreamResolverIOS{
|
||||||
upstreamResolverBase: upstreamResolverBase,
|
upstreamResolverBase: upstreamResolverBase,
|
||||||
lIP: wgIface.Address().IP,
|
wgIface: wgIface,
|
||||||
lNet: wgIface.Address().Network,
|
|
||||||
lIPv6: wgIface.Address().IPv6,
|
|
||||||
lNetV6: wgIface.Address().IPv6Net,
|
|
||||||
interfaceName: wgIface.Name(),
|
|
||||||
}
|
}
|
||||||
ios.upstreamClient = ios
|
ios.upstreamClient = ios
|
||||||
|
|
||||||
@@ -69,24 +61,15 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
|||||||
} else {
|
} else {
|
||||||
upstreamIP = upstreamIP.Unmap()
|
upstreamIP = upstreamIP.Unmap()
|
||||||
}
|
}
|
||||||
needsPrivate := u.lNet.Contains(upstreamIP) ||
|
addr := u.wgIface.Address()
|
||||||
u.lNetV6.Contains(upstreamIP) ||
|
needsPrivate := addr.Network.Contains(upstreamIP) ||
|
||||||
|
addr.IPv6Net.Contains(upstreamIP) ||
|
||||||
(u.routeMatch != nil && u.routeMatch(upstreamIP))
|
(u.routeMatch != nil && u.routeMatch(upstreamIP))
|
||||||
if needsPrivate {
|
if needsPrivate {
|
||||||
var bindIP netip.Addr
|
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
|
||||||
switch {
|
client, err = GetClientPrivate(u.wgIface, upstreamIP, timeout)
|
||||||
case upstreamIP.Is6() && u.lIPv6.IsValid():
|
if err != nil {
|
||||||
bindIP = u.lIPv6
|
return nil, 0, fmt.Errorf("create private client: %s", err)
|
||||||
case upstreamIP.Is4() && u.lIP.IsValid():
|
|
||||||
bindIP = u.lIP
|
|
||||||
}
|
|
||||||
|
|
||||||
if bindIP.IsValid() {
|
|
||||||
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
|
|
||||||
client, err = GetClientPrivate(bindIP, u.interfaceName, timeout)
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, fmt.Errorf("create private client: %s", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -94,23 +77,29 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
|||||||
return ExchangeWithFallback(nil, client, r, upstream)
|
return ExchangeWithFallback(nil, client, r, upstream)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
// GetClientPrivate returns a new DNS client bound to the local IP of the Netbird interface.
|
||||||
// This method is needed for iOS
|
// It selects the v6 bind address when the upstream is IPv6 and the interface has one, otherwise v4.
|
||||||
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
func GetClientPrivate(iface privateClientIface, upstreamIP netip.Addr, dialTimeout time.Duration) (*dns.Client, error) {
|
||||||
index, err := getInterfaceIndex(interfaceName)
|
index, err := getInterfaceIndex(iface.Name())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
|
log.Debugf("unable to get interface index for %s: %s", iface.Name(), err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
addr := iface.Address()
|
||||||
|
bindIP := addr.IP
|
||||||
|
if upstreamIP.Is6() && addr.HasIPv6() {
|
||||||
|
bindIP = addr.IPv6
|
||||||
|
}
|
||||||
|
|
||||||
proto, opt := unix.IPPROTO_IP, unix.IP_BOUND_IF
|
proto, opt := unix.IPPROTO_IP, unix.IP_BOUND_IF
|
||||||
if ip.Is6() {
|
if bindIP.Is6() {
|
||||||
proto, opt = unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF
|
proto, opt = unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF
|
||||||
}
|
}
|
||||||
|
|
||||||
dialer := &net.Dialer{
|
dialer := &net.Dialer{
|
||||||
LocalAddr: net.UDPAddrFromAddrPort(netip.AddrPortFrom(ip, 0)),
|
LocalAddr: net.UDPAddrFromAddrPort(netip.AddrPortFrom(bindIP, 0)),
|
||||||
Timeout: dialTimeout,
|
Timeout: dialTimeout,
|
||||||
Control: func(network, address string, c syscall.RawConn) error {
|
Control: func(network, address string, c syscall.RawConn) error {
|
||||||
var operr error
|
var operr error
|
||||||
fn := func(s uintptr) {
|
fn := func(s uintptr) {
|
||||||
|
|||||||
@@ -80,6 +80,7 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IPv4-only: peers reach the forwarder via its v4 overlay address.
|
||||||
localAddr := m.wgIface.Address().IP
|
localAddr := m.wgIface.Address().IP
|
||||||
|
|
||||||
if localAddr.IsValid() && m.firewall != nil {
|
if localAddr.IsValid() && m.firewall != nil {
|
||||||
|
|||||||
@@ -2,7 +2,8 @@ package ebpf
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"net"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
@@ -12,7 +13,7 @@ const (
|
|||||||
mapKeyDNSPort uint32 = 1
|
mapKeyDNSPort uint32 = 1
|
||||||
)
|
)
|
||||||
|
|
||||||
func (tf *GeneralManager) LoadDNSFwd(ip string, dnsPort int) error {
|
func (tf *GeneralManager) LoadDNSFwd(ip netip.Addr, dnsPort int) error {
|
||||||
log.Debugf("load eBPF DNS forwarder, watching addr: %s:53, redirect to port: %d", ip, dnsPort)
|
log.Debugf("load eBPF DNS forwarder, watching addr: %s:53, redirect to port: %d", ip, dnsPort)
|
||||||
tf.lock.Lock()
|
tf.lock.Lock()
|
||||||
defer tf.lock.Unlock()
|
defer tf.lock.Unlock()
|
||||||
@@ -22,7 +23,11 @@ func (tf *GeneralManager) LoadDNSFwd(ip string, dnsPort int) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tf.bpfObjs.NbMapDnsIp.Put(mapKeyDNSIP, ip2int(ip))
|
if !ip.Is4() {
|
||||||
|
return fmt.Errorf("eBPF DNS forwarder only supports IPv4, got %s", ip)
|
||||||
|
}
|
||||||
|
ip4 := ip.As4()
|
||||||
|
err = tf.bpfObjs.NbMapDnsIp.Put(mapKeyDNSIP, binary.BigEndian.Uint32(ip4[:]))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -45,7 +50,3 @@ func (tf *GeneralManager) FreeDNSFwd() error {
|
|||||||
return tf.unsetFeatureFlag(featureFlagDnsForwarder)
|
return tf.unsetFeatureFlag(featureFlagDnsForwarder)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ip2int(ipString string) uint32 {
|
|
||||||
ip := net.ParseIP(ipString)
|
|
||||||
return binary.BigEndian.Uint32(ip.To4())
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
package manager
|
package manager
|
||||||
|
|
||||||
|
import "net/netip"
|
||||||
|
|
||||||
// Manager is used to load multiple eBPF programs. E.g., current DNS programs and WireGuard proxy
|
// Manager is used to load multiple eBPF programs. E.g., current DNS programs and WireGuard proxy
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
LoadDNSFwd(ip string, dnsPort int) error
|
LoadDNSFwd(ip netip.Addr, dnsPort int) error
|
||||||
FreeDNSFwd() error
|
FreeDNSFwd() error
|
||||||
LoadWgProxy(proxyPort, wgPort int) error
|
LoadWgProxy(proxyPort, wgPort int) error
|
||||||
FreeWGProxy() error
|
FreeWGProxy() error
|
||||||
|
|||||||
@@ -630,7 +630,7 @@ func (e *Engine) initFirewall() error {
|
|||||||
rosenpassPort := e.rpManager.GetAddress().Port
|
rosenpassPort := e.rpManager.GetAddress().Port
|
||||||
port := firewallManager.Port{Values: []uint16{uint16(rosenpassPort)}}
|
port := firewallManager.Port{Values: []uint16{uint16(rosenpassPort)}}
|
||||||
|
|
||||||
// this rule is static and will be torn down on engine down by the firewall manager
|
// IPv4-only: rosenpass peers connect via AllowedIps[0] which is always v4.
|
||||||
if _, err := e.firewall.AddPeerFiltering(
|
if _, err := e.firewall.AddPeerFiltering(
|
||||||
nil,
|
nil,
|
||||||
net.IP{0, 0, 0, 0},
|
net.IP{0, 0, 0, 0},
|
||||||
@@ -682,10 +682,15 @@ func (e *Engine) blockLanAccess() {
|
|||||||
|
|
||||||
log.Infof("blocking route LAN access for networks: %v", toBlock)
|
log.Infof("blocking route LAN access for networks: %v", toBlock)
|
||||||
v4 := netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
v4 := netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||||
|
v6 := netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||||
for _, network := range toBlock {
|
for _, network := range toBlock {
|
||||||
|
source := v4
|
||||||
|
if network.Addr().Is6() {
|
||||||
|
source = v6
|
||||||
|
}
|
||||||
if _, err := e.firewall.AddRouteFiltering(
|
if _, err := e.firewall.AddRouteFiltering(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{v4},
|
[]netip.Prefix{source},
|
||||||
firewallManager.Network{Prefix: network},
|
firewallManager.Network{Prefix: network},
|
||||||
firewallManager.ProtocolALL,
|
firewallManager.ProtocolALL,
|
||||||
nil,
|
nil,
|
||||||
@@ -1494,10 +1499,10 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
|
|||||||
replacement := make([]peer.State, len(offlinePeers))
|
replacement := make([]peer.State, len(offlinePeers))
|
||||||
for i, offlinePeer := range offlinePeers {
|
for i, offlinePeer := range offlinePeers {
|
||||||
log.Debugf("added offline peer %s", offlinePeer.Fqdn)
|
log.Debugf("added offline peer %s", offlinePeer.Fqdn)
|
||||||
v4, v6 := splitAllowedIPs(offlinePeer.GetAllowedIps(), e.wgInterface.Address().IPv6Net)
|
v4, v6 := overlayAddrsFromAllowedIPs(offlinePeer.GetAllowedIps(), e.wgInterface.Address().IPv6Net)
|
||||||
replacement[i] = peer.State{
|
replacement[i] = peer.State{
|
||||||
IP: v4,
|
IP: addrToString(v4),
|
||||||
IPv6: v6,
|
IPv6: addrToString(v6),
|
||||||
PubKey: offlinePeer.GetWgPubKey(),
|
PubKey: offlinePeer.GetWgPubKey(),
|
||||||
FQDN: offlinePeer.GetFqdn(),
|
FQDN: offlinePeer.GetFqdn(),
|
||||||
ConnStatus: peer.StatusIdle,
|
ConnStatus: peer.StatusIdle,
|
||||||
@@ -1508,30 +1513,37 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
|
|||||||
e.statusRecorder.ReplaceOfflinePeers(replacement)
|
e.statusRecorder.ReplaceOfflinePeers(replacement)
|
||||||
}
|
}
|
||||||
|
|
||||||
// splitAllowedIPs separates the peer's overlay v4 (/32) and v6 (/128) addresses
|
// overlayAddrsFromAllowedIPs extracts the peer's v4 and v6 overlay addresses
|
||||||
// from a list of AllowedIPs CIDRs. The v6 address is only matched if it falls
|
// from AllowedIPs strings. Only host routes (/32, /128) are considered; v6 must
|
||||||
// within ourV6Net (the local overlay v6 subnet), to avoid confusing routed /128
|
// fall within ourV6Net to distinguish overlay addresses from routed prefixes.
|
||||||
// prefixes with the peer's overlay address.
|
func overlayAddrsFromAllowedIPs(allowedIPs []string, ourV6Net netip.Prefix) (v4, v6 netip.Addr) {
|
||||||
func splitAllowedIPs(allowedIPs []string, ourV6Net netip.Prefix) (v4, v6 string) {
|
|
||||||
for _, cidr := range allowedIPs {
|
for _, cidr := range allowedIPs {
|
||||||
prefix, err := netip.ParsePrefix(cidr)
|
prefix, err := netip.ParsePrefix(cidr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to parse AllowedIP %q: %v", cidr, err)
|
log.Warnf("failed to parse AllowedIP %q: %v", cidr, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
addr := prefix.Addr().Unmap()
|
||||||
switch {
|
switch {
|
||||||
case prefix.Addr().Is4() && prefix.Bits() == 32 && v4 == "":
|
case addr.Is4() && prefix.Bits() == 32 && !v4.IsValid():
|
||||||
v4 = prefix.Addr().String()
|
v4 = addr
|
||||||
case prefix.Addr().Is6() && prefix.Bits() == 128 && ourV6Net.Contains(prefix.Addr()) && v6 == "":
|
case addr.Is6() && prefix.Bits() == 128 && ourV6Net.Contains(addr) && !v6.IsValid():
|
||||||
v6 = prefix.Addr().String()
|
v6 = addr
|
||||||
}
|
}
|
||||||
if v4 != "" && v6 != "" {
|
if v4.IsValid() && v6.IsValid() {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func addrToString(addr netip.Addr) string {
|
||||||
|
if !addr.IsValid() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return addr.String()
|
||||||
|
}
|
||||||
|
|
||||||
// addNewPeers adds peers that were not know before but arrived from the Management service with the update
|
// addNewPeers adds peers that were not know before but arrived from the Management service with the update
|
||||||
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||||
for _, p := range peersUpdate {
|
for _, p := range peersUpdate {
|
||||||
@@ -1572,8 +1584,8 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
|||||||
return fmt.Errorf("create peer connection: %w", err)
|
return fmt.Errorf("create peer connection: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
peerV4, peerV6 := splitAllowedIPs(peerConfig.GetAllowedIps(), e.wgInterface.Address().IPv6Net)
|
peerV4, peerV6 := overlayAddrsFromAllowedIPs(peerConfig.GetAllowedIps(), e.wgInterface.Address().IPv6Net)
|
||||||
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn, peerV4, peerV6)
|
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn, addrToString(peerV4), addrToString(peerV6))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
||||||
}
|
}
|
||||||
@@ -2355,8 +2367,7 @@ func getInterfacePrefixes() ([]netip.Prefix, error) {
|
|||||||
prefix := netip.PrefixFrom(addr.Unmap(), ones).Masked()
|
prefix := netip.PrefixFrom(addr.Unmap(), ones).Masked()
|
||||||
ip := prefix.Addr()
|
ip := prefix.Addr()
|
||||||
|
|
||||||
// TODO: add IPv6
|
if ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||||
if !ip.Is4() || ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -145,13 +145,13 @@ func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) []
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
peerIP, peerIPv6 := e.extractPeerIPs(peerConfig)
|
peerV4, peerV6 := overlayAddrsFromAllowedIPs(peerConfig.GetAllowedIps(), e.wgInterface.Address().IPv6Net)
|
||||||
hostname := e.extractHostname(peerConfig)
|
hostname := e.extractHostname(peerConfig)
|
||||||
|
|
||||||
peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{
|
peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{
|
||||||
Hostname: hostname,
|
Hostname: hostname,
|
||||||
IP: peerIP,
|
IP: peerV4,
|
||||||
IPv6: peerIPv6,
|
IPv6: peerV6,
|
||||||
FQDN: peerConfig.GetFqdn(),
|
FQDN: peerConfig.GetFqdn(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -159,28 +159,6 @@ func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) []
|
|||||||
return peerInfo
|
return peerInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractPeerIPs extracts IPv4 and IPv6 overlay addresses from peer's allowed IPs.
|
|
||||||
// Only considers host routes (/32, /128) within the overlay networks to avoid
|
|
||||||
// picking up routed prefixes or static routes like 2620:fe::fe/128.
|
|
||||||
func (e *Engine) extractPeerIPs(peerConfig *mgmProto.RemotePeerConfig) (v4, v6 netip.Addr) {
|
|
||||||
wgAddr := e.wgInterface.Address()
|
|
||||||
for _, allowedIP := range peerConfig.GetAllowedIps() {
|
|
||||||
prefix, err := netip.ParsePrefix(allowedIP)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to parse AllowedIP %q: %v", allowedIP, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
addr := prefix.Addr().Unmap()
|
|
||||||
switch {
|
|
||||||
case addr.Is4() && prefix.Bits() == 32 && wgAddr.Network.Contains(addr) && !v4.IsValid():
|
|
||||||
v4 = addr
|
|
||||||
case addr.Is6() && prefix.Bits() == 128 && wgAddr.IPv6Net.IsValid() && wgAddr.IPv6Net.Contains(addr) && !v6.IsValid():
|
|
||||||
v6 = addr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return v4, v6
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractHostname extracts short hostname from FQDN
|
// extractHostname extracts short hostname from FQDN
|
||||||
func (e *Engine) extractHostname(peerConfig *mgmProto.RemotePeerConfig) string {
|
func (e *Engine) extractHostname(peerConfig *mgmProto.RemotePeerConfig) string {
|
||||||
fqdn := peerConfig.GetFqdn()
|
fqdn := peerConfig.GetFqdn()
|
||||||
|
|||||||
@@ -1837,7 +1837,7 @@ func TestFilterAllowedIPs(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSplitAllowedIPs(t *testing.T) {
|
func TestOverlayAddrsFromAllowedIPs(t *testing.T) {
|
||||||
ourV6Net := netip.MustParsePrefix("fd00:1234:5678:abcd::/64")
|
ourV6Net := netip.MustParsePrefix("fd00:1234:5678:abcd::/64")
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@@ -1900,9 +1900,17 @@ func TestSplitAllowedIPs(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
v4, v6 := splitAllowedIPs(tt.allowedIPs, tt.ourV6Net)
|
v4, v6 := overlayAddrsFromAllowedIPs(tt.allowedIPs, tt.ourV6Net)
|
||||||
assert.Equal(t, tt.wantV4, v4, "v4")
|
if tt.wantV4 == "" {
|
||||||
assert.Equal(t, tt.wantV6, v6, "v6")
|
assert.False(t, v4.IsValid(), "expected no v4")
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, tt.wantV4, v4.String(), "v4")
|
||||||
|
}
|
||||||
|
if tt.wantV6 == "" {
|
||||||
|
assert.False(t, v6.IsValid(), "expected no v6")
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, tt.wantV6, v6.String(), "v6")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ func NewBindListener(wgIface WgInterface, bind device.EndpointManager, cfg lazyc
|
|||||||
// deriveFakeIP creates a deterministic fake IP for bind mode based on peer's NetBird IP.
|
// deriveFakeIP creates a deterministic fake IP for bind mode based on peer's NetBird IP.
|
||||||
// Maps peer IP 100.64.x.y to fake IP 127.2.x.y (similar to relay proxy using 127.1.x.y).
|
// Maps peer IP 100.64.x.y to fake IP 127.2.x.y (similar to relay proxy using 127.1.x.y).
|
||||||
// It finds the peer's actual NetBird IP by checking which allowedIP is in the same subnet as our WG interface.
|
// It finds the peer's actual NetBird IP by checking which allowedIP is in the same subnet as our WG interface.
|
||||||
|
// For IPv6-only peers, the last two bytes of the v6 address are used.
|
||||||
func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, error) {
|
func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, error) {
|
||||||
if len(allowedIPs) == 0 {
|
if len(allowedIPs) == 0 {
|
||||||
return netip.Addr{}, fmt.Errorf("no allowed IPs for peer")
|
return netip.Addr{}, fmt.Errorf("no allowed IPs for peer")
|
||||||
@@ -64,6 +65,7 @@ func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, e
|
|||||||
|
|
||||||
ourNetwork := wgIface.Address().Network
|
ourNetwork := wgIface.Address().Network
|
||||||
|
|
||||||
|
// Try v4 first (preferred: deterministic from overlay IP)
|
||||||
var peerIP netip.Addr
|
var peerIP netip.Addr
|
||||||
for _, allowedIP := range allowedIPs {
|
for _, allowedIP := range allowedIPs {
|
||||||
ip := allowedIP.Addr()
|
ip := allowedIP.Addr()
|
||||||
@@ -76,13 +78,24 @@ func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, e
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !peerIP.IsValid() {
|
if peerIP.IsValid() {
|
||||||
return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs")
|
octets := peerIP.As4()
|
||||||
|
return netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]}), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
octets := peerIP.As4()
|
// Fallback: use last two bytes of first v6 overlay IP
|
||||||
fakeIP := netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]})
|
addr := wgIface.Address()
|
||||||
return fakeIP, nil
|
if addr.IPv6Net.IsValid() {
|
||||||
|
for _, allowedIP := range allowedIPs {
|
||||||
|
ip := allowedIP.Addr()
|
||||||
|
if ip.Is6() && addr.IPv6Net.Contains(ip) {
|
||||||
|
raw := ip.As16()
|
||||||
|
return netip.AddrFrom4([4]byte{127, 2, raw[14], raw[15]}), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *BindListener) setupLazyConn() error {
|
func (d *BindListener) setupLazyConn() error {
|
||||||
|
|||||||
@@ -1055,7 +1055,11 @@ func (d *Status) notifyPeerListChanged() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) notifyAddressChanged() {
|
func (d *Status) notifyAddressChanged() {
|
||||||
d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP)
|
addr := d.localPeer.IP
|
||||||
|
if d.localPeer.IPv6 != "" {
|
||||||
|
addr = addr + "\n" + d.localPeer.IPv6
|
||||||
|
}
|
||||||
|
d.notifier.localAddressChanged(d.localPeer.FQDN, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) numOfPeers() int {
|
func (d *Status) numOfPeers() int {
|
||||||
|
|||||||
@@ -3,9 +3,8 @@ package client
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net/netip"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -566,7 +565,7 @@ func HandlerFromRoute(params common.HandlerParams) RouteHandler {
|
|||||||
return dnsinterceptor.New(params)
|
return dnsinterceptor.New(params)
|
||||||
case handlerTypeDynamic:
|
case handlerTypeDynamic:
|
||||||
dns := nbdns.NewServiceViaMemory(params.WgInterface)
|
dns := nbdns.NewServiceViaMemory(params.WgInterface)
|
||||||
dnsAddr := net.JoinHostPort(dns.RuntimeIP().String(), strconv.Itoa(dns.RuntimePort()))
|
dnsAddr := netip.AddrPortFrom(dns.RuntimeIP(), uint16(dns.RuntimePort()))
|
||||||
return dynamic.NewRoute(params, dnsAddr)
|
return dynamic.NewRoute(params, dnsAddr)
|
||||||
default:
|
default:
|
||||||
return static.NewRoute(params)
|
return static.NewRoute(params)
|
||||||
|
|||||||
@@ -582,7 +582,7 @@ func (d *DnsInterceptor) queryUpstreamDNS(ctx context.Context, w dns.ResponseWri
|
|||||||
if nsNet != nil {
|
if nsNet != nil {
|
||||||
reply, err = nbdns.ExchangeWithNetstack(ctx, nsNet, r, upstream)
|
reply, err = nbdns.ExchangeWithNetstack(ctx, nsNet, r, upstream)
|
||||||
} else {
|
} else {
|
||||||
client, clientErr := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
|
client, clientErr := nbdns.GetClientPrivate(d.wgInterface, upstreamIP, dnsTimeout)
|
||||||
if clientErr != nil {
|
if clientErr != nil {
|
||||||
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", clientErr))
|
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", clientErr))
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -50,10 +50,10 @@ type Route struct {
|
|||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
wgInterface iface.WGIface
|
wgInterface iface.WGIface
|
||||||
resolverAddr string
|
resolverAddr netip.AddrPort
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRoute(params common.HandlerParams, resolverAddr string) *Route {
|
func NewRoute(params common.HandlerParams, resolverAddr netip.AddrPort) *Route {
|
||||||
return &Route{
|
return &Route{
|
||||||
route: params.Route,
|
route: params.Route,
|
||||||
routeRefCounter: params.RouteRefCounter,
|
routeRefCounter: params.RouteRefCounter,
|
||||||
|
|||||||
@@ -17,37 +17,47 @@ import (
|
|||||||
const dialTimeout = 10 * time.Second
|
const dialTimeout = 10 * time.Second
|
||||||
|
|
||||||
func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
|
func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
|
||||||
privateClient, err := nbdns.GetClientPrivate(r.wgInterface.Address().IP, r.wgInterface.Name(), dialTimeout)
|
privateClient, err := nbdns.GetClientPrivate(r.wgInterface, r.resolverAddr.Addr(), dialTimeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error while creating private client: %s", err)
|
return nil, fmt.Errorf("error while creating private client: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := new(dns.Msg)
|
fqdn := dns.Fqdn(domain.PunycodeString())
|
||||||
msg.SetQuestion(dns.Fqdn(domain.PunycodeString()), dns.TypeA)
|
|
||||||
|
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
|
|
||||||
response, _, err := nbdns.ExchangeWithFallback(nil, privateClient, msg, r.resolverAddr)
|
var ips []net.IP
|
||||||
if err != nil {
|
var queryErr error
|
||||||
return nil, fmt.Errorf("DNS query for %s failed after %s: %s ", domain.SafeString(), time.Since(startTime), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if response.Rcode != dns.RcodeSuccess {
|
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
|
||||||
return nil, fmt.Errorf("dns response code: %s", dns.RcodeToString[response.Rcode])
|
msg := new(dns.Msg)
|
||||||
}
|
msg.SetQuestion(fqdn, qtype)
|
||||||
|
|
||||||
ips := make([]net.IP, 0)
|
response, _, err := nbdns.ExchangeWithFallback(nil, privateClient, msg, r.resolverAddr.String())
|
||||||
|
if err != nil {
|
||||||
for _, answ := range response.Answer {
|
if queryErr == nil {
|
||||||
if aRecord, ok := answ.(*dns.A); ok {
|
queryErr = fmt.Errorf("DNS query for %s (type %d) after %s: %w", domain.SafeString(), qtype, time.Since(startTime), err)
|
||||||
ips = append(ips, aRecord.A)
|
}
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
if aaaaRecord, ok := answ.(*dns.AAAA); ok {
|
|
||||||
ips = append(ips, aaaaRecord.AAAA)
|
if response.Rcode != dns.RcodeSuccess {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, answ := range response.Answer {
|
||||||
|
if aRecord, ok := answ.(*dns.A); ok {
|
||||||
|
ips = append(ips, aRecord.A)
|
||||||
|
}
|
||||||
|
if aaaaRecord, ok := answ.(*dns.AAAA); ok {
|
||||||
|
ips = append(ips, aaaaRecord.AAAA)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(ips) == 0 {
|
if len(ips) == 0 {
|
||||||
|
if queryErr != nil {
|
||||||
|
return nil, queryErr
|
||||||
|
}
|
||||||
return nil, fmt.Errorf("no A or AAAA records found for %s", domain.SafeString())
|
return nil, fmt.Errorf("no A or AAAA records found for %s", domain.SafeString())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,93 +1,145 @@
|
|||||||
package fakeip
|
package fakeip
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Manager manages allocation of fake IPs from the 240.0.0.0/8 block
|
var (
|
||||||
type Manager struct {
|
// 240.0.0.1 - 240.255.255.254, block 240.0.0.0/8 (reserved, RFC 1112)
|
||||||
mu sync.Mutex
|
v4Base = netip.AddrFrom4([4]byte{240, 0, 0, 1})
|
||||||
nextIP netip.Addr // Next IP to allocate
|
v4Max = netip.AddrFrom4([4]byte{240, 255, 255, 254})
|
||||||
|
v4Block = netip.PrefixFrom(netip.AddrFrom4([4]byte{240, 0, 0, 0}), 8)
|
||||||
|
|
||||||
|
// 0100::1 - 0100::ffff:ffff:ffff:fffe, block 0100::/64 (discard, RFC 6666)
|
||||||
|
v6Base = netip.AddrFrom16([16]byte{0x01, 0x00, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01})
|
||||||
|
v6Max = netip.AddrFrom16([16]byte{0x01, 0x00, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe})
|
||||||
|
v6Block = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x01, 0x00}), 64)
|
||||||
|
)
|
||||||
|
|
||||||
|
// fakeIPPool holds the allocation state for a single address family.
|
||||||
|
type fakeIPPool struct {
|
||||||
|
nextIP netip.Addr
|
||||||
|
baseIP netip.Addr
|
||||||
|
maxIP netip.Addr
|
||||||
|
block netip.Prefix
|
||||||
allocated map[netip.Addr]netip.Addr // real IP -> fake IP
|
allocated map[netip.Addr]netip.Addr // real IP -> fake IP
|
||||||
fakeToReal map[netip.Addr]netip.Addr // fake IP -> real IP
|
fakeToReal map[netip.Addr]netip.Addr // fake IP -> real IP
|
||||||
baseIP netip.Addr // First usable IP: 240.0.0.1
|
|
||||||
maxIP netip.Addr // Last usable IP: 240.255.255.254
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewManager creates a new fake IP manager using 240.0.0.0/8 block
|
func newPool(base, maxAddr netip.Addr, block netip.Prefix) *fakeIPPool {
|
||||||
func NewManager() *Manager {
|
return &fakeIPPool{
|
||||||
baseIP := netip.AddrFrom4([4]byte{240, 0, 0, 1})
|
nextIP: base,
|
||||||
maxIP := netip.AddrFrom4([4]byte{240, 255, 255, 254})
|
baseIP: base,
|
||||||
|
maxIP: maxAddr,
|
||||||
return &Manager{
|
block: block,
|
||||||
nextIP: baseIP,
|
|
||||||
allocated: make(map[netip.Addr]netip.Addr),
|
allocated: make(map[netip.Addr]netip.Addr),
|
||||||
fakeToReal: make(map[netip.Addr]netip.Addr),
|
fakeToReal: make(map[netip.Addr]netip.Addr),
|
||||||
baseIP: baseIP,
|
|
||||||
maxIP: maxIP,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllocateFakeIP allocates a fake IP for the given real IP
|
// allocate allocates a fake IP for the given real IP.
|
||||||
// Returns the fake IP, or existing fake IP if already allocated
|
// Returns the existing fake IP if already allocated.
|
||||||
func (m *Manager) AllocateFakeIP(realIP netip.Addr) (netip.Addr, error) {
|
func (p *fakeIPPool) allocate(realIP netip.Addr) (netip.Addr, error) {
|
||||||
if !realIP.Is4() {
|
if fakeIP, exists := p.allocated[realIP]; exists {
|
||||||
return netip.Addr{}, fmt.Errorf("only IPv4 addresses supported")
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
|
|
||||||
if fakeIP, exists := m.allocated[realIP]; exists {
|
|
||||||
return fakeIP, nil
|
return fakeIP, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
startIP := m.nextIP
|
startIP := p.nextIP
|
||||||
for {
|
for {
|
||||||
currentIP := m.nextIP
|
currentIP := p.nextIP
|
||||||
|
|
||||||
// Advance to next IP, wrapping at boundary
|
// Advance to next IP, wrapping at boundary
|
||||||
if m.nextIP.Compare(m.maxIP) >= 0 {
|
if p.nextIP.Compare(p.maxIP) >= 0 {
|
||||||
m.nextIP = m.baseIP
|
p.nextIP = p.baseIP
|
||||||
} else {
|
} else {
|
||||||
m.nextIP = m.nextIP.Next()
|
p.nextIP = p.nextIP.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if current IP is available
|
if _, inUse := p.fakeToReal[currentIP]; !inUse {
|
||||||
if _, inUse := m.fakeToReal[currentIP]; !inUse {
|
p.allocated[realIP] = currentIP
|
||||||
m.allocated[realIP] = currentIP
|
p.fakeToReal[currentIP] = realIP
|
||||||
m.fakeToReal[currentIP] = realIP
|
|
||||||
return currentIP, nil
|
return currentIP, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prevent infinite loop if all IPs exhausted
|
if p.nextIP.Compare(startIP) == 0 {
|
||||||
if m.nextIP.Compare(startIP) == 0 {
|
return netip.Addr{}, fmt.Errorf("no more fake IPs available in %s block", p.block)
|
||||||
return netip.Addr{}, fmt.Errorf("no more fake IPs available in 240.0.0.0/8 block")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetFakeIP returns the fake IP for a real IP if it exists
|
// Manager manages allocation of fake IPs for dynamic DNS routes.
|
||||||
|
// IPv4 uses 240.0.0.0/8 (reserved), IPv6 uses 0100::/64 (discard, RFC 6666).
|
||||||
|
type Manager struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
v4 *fakeIPPool
|
||||||
|
v6 *fakeIPPool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager creates a new fake IP manager.
|
||||||
|
func NewManager() *Manager {
|
||||||
|
return &Manager{
|
||||||
|
v4: newPool(v4Base, v4Max, v4Block),
|
||||||
|
v6: newPool(v6Base, v6Max, v6Block),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) pool(ip netip.Addr) *fakeIPPool {
|
||||||
|
if ip.Is6() {
|
||||||
|
return m.v6
|
||||||
|
}
|
||||||
|
return m.v4
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllocateFakeIP allocates a fake IP for the given real IP.
|
||||||
|
func (m *Manager) AllocateFakeIP(realIP netip.Addr) (netip.Addr, error) {
|
||||||
|
realIP = realIP.Unmap()
|
||||||
|
if !realIP.IsValid() {
|
||||||
|
return netip.Addr{}, errors.New("invalid IP address")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
return m.pool(realIP).allocate(realIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFakeIP returns the fake IP for a real IP if it exists.
|
||||||
func (m *Manager) GetFakeIP(realIP netip.Addr) (netip.Addr, bool) {
|
func (m *Manager) GetFakeIP(realIP netip.Addr) (netip.Addr, bool) {
|
||||||
|
realIP = realIP.Unmap()
|
||||||
|
if !realIP.IsValid() {
|
||||||
|
return netip.Addr{}, false
|
||||||
|
}
|
||||||
|
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
fakeIP, exists := m.allocated[realIP]
|
fakeIP, ok := m.pool(realIP).allocated[realIP]
|
||||||
return fakeIP, exists
|
return fakeIP, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRealIP returns the real IP for a fake IP if it exists, otherwise false
|
// GetRealIP returns the real IP for a fake IP if it exists.
|
||||||
func (m *Manager) GetRealIP(fakeIP netip.Addr) (netip.Addr, bool) {
|
func (m *Manager) GetRealIP(fakeIP netip.Addr) (netip.Addr, bool) {
|
||||||
|
fakeIP = fakeIP.Unmap()
|
||||||
|
if !fakeIP.IsValid() {
|
||||||
|
return netip.Addr{}, false
|
||||||
|
}
|
||||||
|
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
realIP, exists := m.fakeToReal[fakeIP]
|
realIP, ok := m.pool(fakeIP).fakeToReal[fakeIP]
|
||||||
return realIP, exists
|
return realIP, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetFakeIPBlock returns the fake IP block used by this manager
|
// GetFakeIPBlock returns the v4 fake IP block used by this manager.
|
||||||
func (m *Manager) GetFakeIPBlock() netip.Prefix {
|
func (m *Manager) GetFakeIPBlock() netip.Prefix {
|
||||||
return netip.MustParsePrefix("240.0.0.0/8")
|
return m.v4.block
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFakeIPv6Block returns the v6 fake IP block used by this manager.
|
||||||
|
func (m *Manager) GetFakeIPv6Block() netip.Prefix {
|
||||||
|
return m.v6.block
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,16 +9,16 @@ import (
|
|||||||
func TestNewManager(t *testing.T) {
|
func TestNewManager(t *testing.T) {
|
||||||
manager := NewManager()
|
manager := NewManager()
|
||||||
|
|
||||||
if manager.baseIP.String() != "240.0.0.1" {
|
if manager.v4.baseIP.String() != "240.0.0.1" {
|
||||||
t.Errorf("Expected base IP 240.0.0.1, got %s", manager.baseIP.String())
|
t.Errorf("Expected v4 base IP 240.0.0.1, got %s", manager.v4.baseIP.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
if manager.maxIP.String() != "240.255.255.254" {
|
if manager.v4.maxIP.String() != "240.255.255.254" {
|
||||||
t.Errorf("Expected max IP 240.255.255.254, got %s", manager.maxIP.String())
|
t.Errorf("Expected v4 max IP 240.255.255.254, got %s", manager.v4.maxIP.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
if manager.nextIP.Compare(manager.baseIP) != 0 {
|
if manager.v6.baseIP.String() != "100::1" {
|
||||||
t.Errorf("Expected nextIP to start at baseIP")
|
t.Errorf("Expected v6 base IP 100::1, got %s", manager.v6.baseIP.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -35,7 +35,6 @@ func TestAllocateFakeIP(t *testing.T) {
|
|||||||
t.Error("Fake IP should be IPv4")
|
t.Error("Fake IP should be IPv4")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check it's in the correct range
|
|
||||||
if fakeIP.As4()[0] != 240 {
|
if fakeIP.As4()[0] != 240 {
|
||||||
t.Errorf("Fake IP should be in 240.0.0.0/8 range, got %s", fakeIP.String())
|
t.Errorf("Fake IP should be in 240.0.0.0/8 range, got %s", fakeIP.String())
|
||||||
}
|
}
|
||||||
@@ -51,13 +50,31 @@ func TestAllocateFakeIP(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAllocateFakeIPIPv6Rejection(t *testing.T) {
|
func TestAllocateFakeIPv6(t *testing.T) {
|
||||||
manager := NewManager()
|
manager := NewManager()
|
||||||
realIPv6 := netip.MustParseAddr("2001:db8::1")
|
realIP := netip.MustParseAddr("2001:db8::1")
|
||||||
|
|
||||||
_, err := manager.AllocateFakeIP(realIPv6)
|
fakeIP, err := manager.AllocateFakeIP(realIP)
|
||||||
if err == nil {
|
if err != nil {
|
||||||
t.Error("Expected error for IPv6 address")
|
t.Fatalf("Failed to allocate fake IPv6: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !fakeIP.Is6() {
|
||||||
|
t.Error("Fake IP should be IPv6")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !netip.MustParsePrefix("100::/64").Contains(fakeIP) {
|
||||||
|
t.Errorf("Fake IP should be in 100::/64 range, got %s", fakeIP.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should return same fake IP for same real IP
|
||||||
|
fakeIP2, err := manager.AllocateFakeIP(realIP)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get existing fake IPv6: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fakeIP.Compare(fakeIP2) != 0 {
|
||||||
|
t.Errorf("Expected same fake IP, got %s and %s", fakeIP.String(), fakeIP2.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,13 +82,11 @@ func TestGetFakeIP(t *testing.T) {
|
|||||||
manager := NewManager()
|
manager := NewManager()
|
||||||
realIP := netip.MustParseAddr("1.1.1.1")
|
realIP := netip.MustParseAddr("1.1.1.1")
|
||||||
|
|
||||||
// Should not exist initially
|
|
||||||
_, exists := manager.GetFakeIP(realIP)
|
_, exists := manager.GetFakeIP(realIP)
|
||||||
if exists {
|
if exists {
|
||||||
t.Error("Fake IP should not exist before allocation")
|
t.Error("Fake IP should not exist before allocation")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allocate and check
|
|
||||||
expectedFakeIP, err := manager.AllocateFakeIP(realIP)
|
expectedFakeIP, err := manager.AllocateFakeIP(realIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to allocate: %v", err)
|
t.Fatalf("Failed to allocate: %v", err)
|
||||||
@@ -87,12 +102,30 @@ func TestGetFakeIP(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetRealIPv6(t *testing.T) {
|
||||||
|
manager := NewManager()
|
||||||
|
realIP := netip.MustParseAddr("2001:db8::1")
|
||||||
|
|
||||||
|
fakeIP, err := manager.AllocateFakeIP(realIP)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to allocate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotReal, exists := manager.GetRealIP(fakeIP)
|
||||||
|
if !exists {
|
||||||
|
t.Error("Real IP should exist for allocated fake IP")
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotReal.Compare(realIP) != 0 {
|
||||||
|
t.Errorf("Expected real IP %s, got %s", realIP, gotReal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestMultipleAllocations(t *testing.T) {
|
func TestMultipleAllocations(t *testing.T) {
|
||||||
manager := NewManager()
|
manager := NewManager()
|
||||||
|
|
||||||
allocations := make(map[netip.Addr]netip.Addr)
|
allocations := make(map[netip.Addr]netip.Addr)
|
||||||
|
|
||||||
// Allocate multiple IPs
|
|
||||||
for i := 1; i <= 100; i++ {
|
for i := 1; i <= 100; i++ {
|
||||||
realIP := netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
realIP := netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||||
fakeIP, err := manager.AllocateFakeIP(realIP)
|
fakeIP, err := manager.AllocateFakeIP(realIP)
|
||||||
@@ -100,7 +133,6 @@ func TestMultipleAllocations(t *testing.T) {
|
|||||||
t.Fatalf("Failed to allocate fake IP for %s: %v", realIP.String(), err)
|
t.Fatalf("Failed to allocate fake IP for %s: %v", realIP.String(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for duplicates
|
|
||||||
for _, existingFake := range allocations {
|
for _, existingFake := range allocations {
|
||||||
if fakeIP.Compare(existingFake) == 0 {
|
if fakeIP.Compare(existingFake) == 0 {
|
||||||
t.Errorf("Duplicate fake IP allocated: %s", fakeIP.String())
|
t.Errorf("Duplicate fake IP allocated: %s", fakeIP.String())
|
||||||
@@ -110,7 +142,6 @@ func TestMultipleAllocations(t *testing.T) {
|
|||||||
allocations[realIP] = fakeIP
|
allocations[realIP] = fakeIP
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify all allocations can be retrieved
|
|
||||||
for realIP, expectedFake := range allocations {
|
for realIP, expectedFake := range allocations {
|
||||||
actualFake, exists := manager.GetFakeIP(realIP)
|
actualFake, exists := manager.GetFakeIP(realIP)
|
||||||
if !exists {
|
if !exists {
|
||||||
@@ -124,11 +155,13 @@ func TestMultipleAllocations(t *testing.T) {
|
|||||||
|
|
||||||
func TestGetFakeIPBlock(t *testing.T) {
|
func TestGetFakeIPBlock(t *testing.T) {
|
||||||
manager := NewManager()
|
manager := NewManager()
|
||||||
block := manager.GetFakeIPBlock()
|
|
||||||
|
|
||||||
expected := "240.0.0.0/8"
|
if block := manager.GetFakeIPBlock(); block.String() != "240.0.0.0/8" {
|
||||||
if block.String() != expected {
|
t.Errorf("Expected 240.0.0.0/8, got %s", block.String())
|
||||||
t.Errorf("Expected %s, got %s", expected, block.String())
|
}
|
||||||
|
|
||||||
|
if block := manager.GetFakeIPv6Block(); block.String() != "100::/64" {
|
||||||
|
t.Errorf("Expected 100::/64, got %s", block.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -141,7 +174,6 @@ func TestConcurrentAccess(t *testing.T) {
|
|||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
results := make(chan netip.Addr, numGoroutines*allocationsPerGoroutine)
|
results := make(chan netip.Addr, numGoroutines*allocationsPerGoroutine)
|
||||||
|
|
||||||
// Concurrent allocations
|
|
||||||
for i := 0; i < numGoroutines; i++ {
|
for i := 0; i < numGoroutines; i++ {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(goroutineID int) {
|
go func(goroutineID int) {
|
||||||
@@ -161,7 +193,6 @@ func TestConcurrentAccess(t *testing.T) {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
close(results)
|
close(results)
|
||||||
|
|
||||||
// Check for duplicates
|
|
||||||
seen := make(map[netip.Addr]bool)
|
seen := make(map[netip.Addr]bool)
|
||||||
count := 0
|
count := 0
|
||||||
for fakeIP := range results {
|
for fakeIP := range results {
|
||||||
@@ -178,47 +209,61 @@ func TestConcurrentAccess(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestIPExhaustion(t *testing.T) {
|
func TestIPExhaustion(t *testing.T) {
|
||||||
// Create a manager with limited range for testing
|
|
||||||
manager := &Manager{
|
manager := &Manager{
|
||||||
nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
v4: newPool(
|
||||||
allocated: make(map[netip.Addr]netip.Addr),
|
netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
||||||
fakeToReal: make(map[netip.Addr]netip.Addr),
|
netip.AddrFrom4([4]byte{240, 0, 0, 3}),
|
||||||
baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
netip.MustParsePrefix("240.0.0.0/8"),
|
||||||
maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 3}), // Only 3 IPs available
|
),
|
||||||
|
v6: newPool(
|
||||||
|
netip.MustParseAddr("100::1"),
|
||||||
|
netip.MustParseAddr("100::3"),
|
||||||
|
netip.MustParsePrefix("100::/64"),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allocate all available IPs
|
for _, realIP := range []string{"1.0.0.1", "1.0.0.2", "1.0.0.3"} {
|
||||||
realIPs := []netip.Addr{
|
_, err := manager.AllocateFakeIP(netip.MustParseAddr(realIP))
|
||||||
netip.MustParseAddr("1.0.0.1"),
|
|
||||||
netip.MustParseAddr("1.0.0.2"),
|
|
||||||
netip.MustParseAddr("1.0.0.3"),
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, realIP := range realIPs {
|
|
||||||
_, err := manager.AllocateFakeIP(realIP)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to allocate fake IP: %v", err)
|
t.Fatalf("Failed to allocate fake IP: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to allocate one more - should fail
|
|
||||||
_, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.4"))
|
_, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.4"))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expected exhaustion error")
|
t.Error("Expected v4 exhaustion error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Same for v6
|
||||||
|
for _, realIP := range []string{"2001:db8::1", "2001:db8::2", "2001:db8::3"} {
|
||||||
|
_, err := manager.AllocateFakeIP(netip.MustParseAddr(realIP))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to allocate fake IPv6: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = manager.AllocateFakeIP(netip.MustParseAddr("2001:db8::4"))
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected v6 exhaustion error")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWrapAround(t *testing.T) {
|
func TestWrapAround(t *testing.T) {
|
||||||
// Create manager starting near the end of range
|
|
||||||
manager := &Manager{
|
manager := &Manager{
|
||||||
nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}),
|
v4: newPool(
|
||||||
allocated: make(map[netip.Addr]netip.Addr),
|
netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
||||||
fakeToReal: make(map[netip.Addr]netip.Addr),
|
netip.AddrFrom4([4]byte{240, 0, 0, 254}),
|
||||||
baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
netip.MustParsePrefix("240.0.0.0/8"),
|
||||||
maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}),
|
),
|
||||||
|
v6: newPool(
|
||||||
|
netip.MustParseAddr("100::1"),
|
||||||
|
netip.MustParseAddr("100::ffff:ffff:ffff:fffe"),
|
||||||
|
netip.MustParsePrefix("100::/64"),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
// Start near the end
|
||||||
|
manager.v4.nextIP = netip.AddrFrom4([4]byte{240, 0, 0, 254})
|
||||||
|
|
||||||
// Allocate the last IP
|
|
||||||
fakeIP1, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.1"))
|
fakeIP1, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.1"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to allocate first IP: %v", err)
|
t.Fatalf("Failed to allocate first IP: %v", err)
|
||||||
@@ -228,7 +273,6 @@ func TestWrapAround(t *testing.T) {
|
|||||||
t.Errorf("Expected 240.0.0.254, got %s", fakeIP1.String())
|
t.Errorf("Expected 240.0.0.254, got %s", fakeIP1.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Next allocation should wrap around to the beginning
|
|
||||||
fakeIP2, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.2"))
|
fakeIP2, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.2"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to allocate second IP: %v", err)
|
t.Fatalf("Failed to allocate second IP: %v", err)
|
||||||
@@ -238,3 +282,32 @@ func TestWrapAround(t *testing.T) {
|
|||||||
t.Errorf("Expected 240.0.0.1 after wrap, got %s", fakeIP2.String())
|
t.Errorf("Expected 240.0.0.1 after wrap, got %s", fakeIP2.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMixedV4V6(t *testing.T) {
|
||||||
|
manager := NewManager()
|
||||||
|
|
||||||
|
v4Fake, err := manager.AllocateFakeIP(netip.MustParseAddr("8.8.8.8"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to allocate v4: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
v6Fake, err := manager.AllocateFakeIP(netip.MustParseAddr("2001:db8::1"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to allocate v6: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !v4Fake.Is4() || !v6Fake.Is6() {
|
||||||
|
t.Errorf("Wrong families: v4=%s v6=%s", v4Fake, v6Fake)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reverse lookups should work for both
|
||||||
|
gotV4, ok := manager.GetRealIP(v4Fake)
|
||||||
|
if !ok || gotV4.String() != "8.8.8.8" {
|
||||||
|
t.Errorf("v4 reverse lookup failed: got %s, ok=%v", gotV4, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotV6, ok := manager.GetRealIP(v6Fake)
|
||||||
|
if !ok || gotV6.String() != "2001:db8::1" {
|
||||||
|
t.Errorf("v6 reverse lookup failed: got %s, ok=%v", gotV6, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,7 +9,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// IPForwardingState is a struct that keeps track of the IP forwarding state.
|
// IPForwardingState is a struct that keeps track of the IP forwarding state.
|
||||||
// todo: read initial state of the IP forwarding from the system and reset the state based on it
|
// todo: read initial state of the IP forwarding from the system and reset the state based on it.
|
||||||
|
// todo: separate v4/v6 forwarding state, since the sysctls are independent
|
||||||
|
// (net.ipv4.ip_forward vs net.ipv6.conf.all.forwarding). Currently the nftables
|
||||||
|
// manager shares one instance between both routers, which works only because
|
||||||
|
// EnableIPForwarding enables both sysctls in a single call.
|
||||||
type IPForwardingState struct {
|
type IPForwardingState struct {
|
||||||
enabledCounter int
|
enabledCounter int
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -159,15 +159,23 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
|
|||||||
if config.DNSFeatureFlag {
|
if config.DNSFeatureFlag {
|
||||||
m.fakeIPManager = fakeip.NewManager()
|
m.fakeIPManager = fakeip.NewManager()
|
||||||
|
|
||||||
id := uuid.NewString()
|
v4ID := uuid.NewString()
|
||||||
fakeIPRoute := &route.Route{
|
cr = append(cr, &route.Route{
|
||||||
ID: route.ID(id),
|
ID: route.ID(v4ID),
|
||||||
Network: m.fakeIPManager.GetFakeIPBlock(),
|
Network: m.fakeIPManager.GetFakeIPBlock(),
|
||||||
NetID: route.NetID(id),
|
NetID: route.NetID(v4ID),
|
||||||
Peer: m.pubKey,
|
Peer: m.pubKey,
|
||||||
NetworkType: route.IPv4Network,
|
NetworkType: route.IPv4Network,
|
||||||
}
|
})
|
||||||
cr = append(cr, fakeIPRoute)
|
|
||||||
|
v6ID := uuid.NewString()
|
||||||
|
cr = append(cr, &route.Route{
|
||||||
|
ID: route.ID(v6ID),
|
||||||
|
Network: m.fakeIPManager.GetFakeIPv6Block(),
|
||||||
|
NetID: route.NetID(v6ID),
|
||||||
|
Peer: m.pubKey,
|
||||||
|
NetworkType: route.IPv6Network,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
m.notifier.SetInitialClientRoutes(cr, routesForComparison)
|
m.notifier.SetInitialClientRoutes(cr, routesForComparison)
|
||||||
|
|||||||
@@ -146,8 +146,7 @@ func routeToRouterPair(route *route.Route, useNewDNSRoute bool) firewall.RouterP
|
|||||||
if useNewDNSRoute {
|
if useNewDNSRoute {
|
||||||
destination.Set = firewall.NewDomainSet(route.Domains)
|
destination.Set = firewall.NewDomainSet(route.Domains)
|
||||||
} else {
|
} else {
|
||||||
// TODO: add ipv6 additionally
|
destination = getDefaultPrefix(route.Network)
|
||||||
destination = getDefaultPrefix(destination.Prefix)
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
destination.Prefix = route.Network.Masked()
|
destination.Prefix = route.Network.Masked()
|
||||||
|
|||||||
@@ -107,8 +107,13 @@ func (r *SysOps) validateRoute(prefix netip.Prefix) error {
|
|||||||
addr.IsInterfaceLocalMulticast(),
|
addr.IsInterfaceLocalMulticast(),
|
||||||
addr.IsMulticast(),
|
addr.IsMulticast(),
|
||||||
addr.IsUnspecified() && prefix.Bits() != 0,
|
addr.IsUnspecified() && prefix.Bits() != 0,
|
||||||
r.wgInterface.Address().Network.Contains(addr):
|
r.isOwnAddress(addr):
|
||||||
return vars.ErrRouteNotAllowed
|
return vars.ErrRouteNotAllowed
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) isOwnAddress(addr netip.Addr) bool {
|
||||||
|
wgAddr := r.wgInterface.Address()
|
||||||
|
return wgAddr.Network.Contains(addr) || (wgAddr.IPv6Net.IsValid() && wgAddr.IPv6Net.Contains(addr))
|
||||||
|
}
|
||||||
|
|||||||
@@ -222,30 +222,20 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: remove once IPv6 is supported on the interface
|
// When the interface has no v6, add v6 split-default as blackhole so
|
||||||
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
// unroutable v6 goes to WG (dropped, no AllowedIPs) instead of leaking
|
||||||
return fmt.Errorf("add unreachable route split 1: %w", err)
|
// to the system default route. When v6 is active, management sends ::/0
|
||||||
}
|
// as a separate route that the dedicated handler adds.
|
||||||
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
// Soft-fail: v6 blackhole is best-effort, don't abort v4 routing on failure.
|
||||||
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil {
|
if !r.wgInterface.Address().HasIPv6() {
|
||||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
if err := r.addV6SplitDefault(nextHop); err != nil {
|
||||||
|
log.Warnf("failed to add v6 split-default blackhole: %s", err)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("add unreachable route split 2: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
case vars.Defaultv6:
|
case vars.Defaultv6:
|
||||||
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
return r.addV6SplitDefault(nextHop)
|
||||||
return fmt.Errorf("add unreachable route split 1: %w", err)
|
|
||||||
}
|
|
||||||
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
|
||||||
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil {
|
|
||||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("add unreachable route split 2: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.addToRouteTable(prefix, nextHop)
|
return r.addToRouteTable(prefix, nextHop)
|
||||||
@@ -266,30 +256,42 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
|
|||||||
result = multierror.Append(result, err)
|
result = multierror.Append(result, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: remove once IPv6 is supported on the interface
|
if !r.wgInterface.Address().HasIPv6() {
|
||||||
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
result = multierror.Append(result, r.removeV6SplitDefault(nextHop))
|
||||||
result = multierror.Append(result, err)
|
|
||||||
}
|
|
||||||
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
|
||||||
result = multierror.Append(result, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(result)
|
return nberrors.FormatErrorOrNil(result)
|
||||||
case vars.Defaultv6:
|
case vars.Defaultv6:
|
||||||
var result *multierror.Error
|
return nberrors.FormatErrorOrNil(r.removeV6SplitDefault(nextHop))
|
||||||
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
|
||||||
result = multierror.Append(result, err)
|
|
||||||
}
|
|
||||||
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
|
||||||
result = multierror.Append(result, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(result)
|
|
||||||
default:
|
default:
|
||||||
return r.removeFromRouteTable(prefix, nextHop)
|
return r.removeFromRouteTable(prefix, nextHop)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) addV6SplitDefault(nextHop Nexthop) error {
|
||||||
|
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
||||||
|
return fmt.Errorf("add split 1: %w", err)
|
||||||
|
}
|
||||||
|
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
||||||
|
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil {
|
||||||
|
log.Warnf("Failed to rollback v6 split-default: %s", err2)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("add split 2: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) removeV6SplitDefault(nextHop Nexthop) *multierror.Error {
|
||||||
|
var result *multierror.Error
|
||||||
|
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||||
beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error {
|
beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error {
|
||||||
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil {
|
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil {
|
||||||
|
|||||||
@@ -53,6 +53,8 @@ const (
|
|||||||
|
|
||||||
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
|
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
|
||||||
ipv4ForwardingPath = "net.ipv4.ip_forward"
|
ipv4ForwardingPath = "net.ipv4.ip_forward"
|
||||||
|
// ipv6ForwardingPath is the path to the file containing the IPv6 forwarding setting.
|
||||||
|
ipv6ForwardingPath = "net.ipv6.conf.all.forwarding"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrTableIDExists = errors.New("ID exists with different name")
|
var ErrTableIDExists = errors.New("ID exists with different name")
|
||||||
@@ -185,10 +187,11 @@ func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
|||||||
|
|
||||||
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
|
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
|
||||||
|
|
||||||
// TODO remove this once we have ipv6 support
|
// When the peer has no IPv6, blackhole v6 to prevent leaking.
|
||||||
if prefix == vars.Defaultv4 {
|
// When IPv6 is enabled, management sends ::/0 as a separate route.
|
||||||
|
if prefix == vars.Defaultv4 && (r.wgInterface == nil || !r.wgInterface.Address().HasIPv6()) {
|
||||||
if err := addUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil {
|
if err := addUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil {
|
||||||
return fmt.Errorf("add blackhole: %w", err)
|
return fmt.Errorf("add v6 blackhole: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := addRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil {
|
if err := addRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil {
|
||||||
@@ -206,10 +209,9 @@ func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error
|
|||||||
return r.genericRemoveVPNRoute(prefix, intf)
|
return r.genericRemoveVPNRoute(prefix, intf)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO remove this once we have ipv6 support
|
if prefix == vars.Defaultv4 && (r.wgInterface == nil || !r.wgInterface.Address().HasIPv6()) {
|
||||||
if prefix == vars.Defaultv4 {
|
|
||||||
if err := removeUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil {
|
if err := removeUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil {
|
||||||
return fmt.Errorf("remove unreachable route: %w", err)
|
log.Debugf("remove v6 blackhole: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := removeRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil {
|
if err := removeRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil {
|
||||||
@@ -762,8 +764,13 @@ func flushRoutes(tableID, family int) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func EnableIPForwarding() error {
|
func EnableIPForwarding() error {
|
||||||
_, err := sysctl.Set(ipv4ForwardingPath, 1, false)
|
if _, err := sysctl.Set(ipv4ForwardingPath, 1, false); err != nil {
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
|
if _, err := sysctl.Set(ipv6ForwardingPath, 1, false); err != nil {
|
||||||
|
log.Warnf("failed to enable IPv6 forwarding: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// entryExists checks if the specified ID or name already exists in the rt_tables file
|
// entryExists checks if the specified ID or name already exists in the rt_tables file
|
||||||
|
|||||||
@@ -50,10 +50,11 @@ type CustomLogger interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type selectRoute struct {
|
type selectRoute struct {
|
||||||
NetID string
|
NetID string
|
||||||
Network netip.Prefix
|
Network netip.Prefix
|
||||||
Domains domain.List
|
Domains domain.List
|
||||||
Selected bool
|
Selected bool
|
||||||
|
extraNetworks []netip.Prefix
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -363,48 +364,60 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
routeManager := engine.GetRouteManager()
|
routeManager := engine.GetRouteManager()
|
||||||
routesMap := routeManager.GetClientRoutesWithNetID()
|
|
||||||
if routeManager == nil {
|
if routeManager == nil {
|
||||||
return nil, fmt.Errorf("could not get route manager")
|
return nil, fmt.Errorf("could not get route manager")
|
||||||
}
|
}
|
||||||
|
routesMap := routeManager.GetClientRoutesWithNetID()
|
||||||
routeSelector := routeManager.GetRouteSelector()
|
routeSelector := routeManager.GetRouteSelector()
|
||||||
if routeSelector == nil {
|
if routeSelector == nil {
|
||||||
return nil, fmt.Errorf("could not get route selector")
|
return nil, fmt.Errorf("could not get route selector")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
v6ExitMerged := route.V6ExitMergeSet(routesMap)
|
||||||
|
routes := buildSelectRoutes(routesMap, routeSelector.IsSelected, v6ExitMerged)
|
||||||
|
resolvedDomains := c.recorder.GetResolvedDomainsStates()
|
||||||
|
|
||||||
|
return prepareRouteSelectionDetails(routes, resolvedDomains), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildSelectRoutes(routesMap map[route.NetID][]*route.Route, isSelected func(route.NetID) bool, v6Merged map[route.NetID]struct{}) []*selectRoute {
|
||||||
var routes []*selectRoute
|
var routes []*selectRoute
|
||||||
for id, rt := range routesMap {
|
for id, rt := range routesMap {
|
||||||
if len(rt) == 0 {
|
if len(rt) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
route := &selectRoute{
|
if _, ok := v6Merged[id]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
r := &selectRoute{
|
||||||
NetID: string(id),
|
NetID: string(id),
|
||||||
Network: rt[0].Network,
|
Network: rt[0].Network,
|
||||||
Domains: rt[0].Domains,
|
Domains: rt[0].Domains,
|
||||||
Selected: routeSelector.IsSelected(id),
|
Selected: isSelected(id),
|
||||||
}
|
}
|
||||||
routes = append(routes, route)
|
|
||||||
|
v6ID := route.NetID(string(id) + route.V6ExitSuffix)
|
||||||
|
if _, ok := v6Merged[v6ID]; ok {
|
||||||
|
r.extraNetworks = []netip.Prefix{routesMap[v6ID][0].Network}
|
||||||
|
}
|
||||||
|
|
||||||
|
routes = append(routes, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Slice(routes, func(i, j int) bool {
|
sort.Slice(routes, func(i, j int) bool {
|
||||||
iPrefix := routes[i].Network.Bits()
|
iBits, jBits := routes[i].Network.Bits(), routes[j].Network.Bits()
|
||||||
jPrefix := routes[j].Network.Bits()
|
if iBits != jBits {
|
||||||
|
return iBits < jBits
|
||||||
if iPrefix == jPrefix {
|
|
||||||
iAddr := routes[i].Network.Addr()
|
|
||||||
jAddr := routes[j].Network.Addr()
|
|
||||||
if iAddr == jAddr {
|
|
||||||
return routes[i].NetID < routes[j].NetID
|
|
||||||
}
|
|
||||||
return iAddr.String() < jAddr.String()
|
|
||||||
}
|
}
|
||||||
return iPrefix < jPrefix
|
iAddr, jAddr := routes[i].Network.Addr(), routes[j].Network.Addr()
|
||||||
|
if iAddr != jAddr {
|
||||||
|
return iAddr.Less(jAddr)
|
||||||
|
}
|
||||||
|
return routes[i].NetID < routes[j].NetID
|
||||||
})
|
})
|
||||||
|
|
||||||
resolvedDomains := c.recorder.GetResolvedDomainsStates()
|
return routes
|
||||||
|
|
||||||
return prepareRouteSelectionDetails(routes, resolvedDomains), nil
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) *RoutesSelectionDetails {
|
func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) *RoutesSelectionDetails {
|
||||||
@@ -425,10 +438,15 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom
|
|||||||
}
|
}
|
||||||
domainList = append(domainList, domainResp)
|
domainList = append(domainList, domainResp)
|
||||||
}
|
}
|
||||||
|
rangeStr := r.Network.String()
|
||||||
|
for _, extra := range r.extraNetworks {
|
||||||
|
rangeStr += ", " + extra.String()
|
||||||
|
}
|
||||||
|
|
||||||
domainDetails := DomainDetails{items: domainList}
|
domainDetails := DomainDetails{items: domainList}
|
||||||
routeSelection = append(routeSelection, RoutesSelectionInfo{
|
routeSelection = append(routeSelection, RoutesSelectionInfo{
|
||||||
ID: r.NetID,
|
ID: r.NetID,
|
||||||
Network: r.Network.String(),
|
Network: rangeStr,
|
||||||
Domains: &domainDetails,
|
Domains: &domainDetails,
|
||||||
Selected: r.Selected,
|
Selected: r.Selected,
|
||||||
})
|
})
|
||||||
@@ -456,7 +474,9 @@ func (c *Client) SelectRoute(id string) error {
|
|||||||
} else {
|
} else {
|
||||||
log.Debugf("select route with id: %s", id)
|
log.Debugf("select route with id: %s", id)
|
||||||
routes := toNetIDs([]string{id})
|
routes := toNetIDs([]string{id})
|
||||||
if err := routeSelector.SelectRoutes(routes, true, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil {
|
routesMap := routeManager.GetClientRoutesWithNetID()
|
||||||
|
routes = route.ExpandV6ExitPairs(routes, routesMap)
|
||||||
|
if err := routeSelector.SelectRoutes(routes, true, maps.Keys(routesMap)); err != nil {
|
||||||
log.Debugf("error when selecting routes: %s", err)
|
log.Debugf("error when selecting routes: %s", err)
|
||||||
return fmt.Errorf("select routes: %w", err)
|
return fmt.Errorf("select routes: %w", err)
|
||||||
}
|
}
|
||||||
@@ -483,7 +503,9 @@ func (c *Client) DeselectRoute(id string) error {
|
|||||||
} else {
|
} else {
|
||||||
log.Debugf("deselect route with id: %s", id)
|
log.Debugf("deselect route with id: %s", id)
|
||||||
routes := toNetIDs([]string{id})
|
routes := toNetIDs([]string{id})
|
||||||
if err := routeSelector.DeselectRoutes(routes, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil {
|
routesMap := routeManager.GetClientRoutesWithNetID()
|
||||||
|
routes = route.ExpandV6ExitPairs(routes, routesMap)
|
||||||
|
if err := routeSelector.DeselectRoutes(routes, maps.Keys(routesMap)); err != nil {
|
||||||
log.Debugf("error when deselecting routes: %s", err)
|
log.Debugf("error when deselecting routes: %s", err)
|
||||||
return fmt.Errorf("deselect routes: %w", err)
|
return fmt.Errorf("deselect routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,10 +16,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type selectRoute struct {
|
type selectRoute struct {
|
||||||
NetID route.NetID
|
NetID route.NetID
|
||||||
Network netip.Prefix
|
Network netip.Prefix
|
||||||
Domains domain.List
|
Domains domain.List
|
||||||
Selected bool
|
Selected bool
|
||||||
|
extraNetworks []netip.Prefix
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListNetworks returns a list of all available networks.
|
// ListNetworks returns a list of all available networks.
|
||||||
@@ -44,18 +45,33 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
|
|||||||
routesMap := routeMgr.GetClientRoutesWithNetID()
|
routesMap := routeMgr.GetClientRoutesWithNetID()
|
||||||
routeSelector := routeMgr.GetRouteSelector()
|
routeSelector := routeMgr.GetRouteSelector()
|
||||||
|
|
||||||
|
v6ExitMerged := route.V6ExitMergeSet(routesMap)
|
||||||
|
|
||||||
var routes []*selectRoute
|
var routes []*selectRoute
|
||||||
for id, rt := range routesMap {
|
for id, rt := range routesMap {
|
||||||
if len(rt) == 0 {
|
if len(rt) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
route := &selectRoute{
|
// Skip v6 exit nodes that are merged into their v4 counterpart.
|
||||||
|
if _, ok := v6ExitMerged[id]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
r := &selectRoute{
|
||||||
NetID: id,
|
NetID: id,
|
||||||
Network: rt[0].Network,
|
Network: rt[0].Network,
|
||||||
Domains: rt[0].Domains,
|
Domains: rt[0].Domains,
|
||||||
Selected: routeSelector.IsSelected(id),
|
Selected: routeSelector.IsSelected(id),
|
||||||
}
|
}
|
||||||
routes = append(routes, route)
|
|
||||||
|
// Merge paired v6 exit node prefix into this entry.
|
||||||
|
v6ID := route.NetID(string(id) + route.V6ExitSuffix)
|
||||||
|
if _, ok := v6ExitMerged[v6ID]; ok {
|
||||||
|
v6Prefix := routesMap[v6ID][0].Network
|
||||||
|
r.extraNetworks = []netip.Prefix{v6Prefix}
|
||||||
|
}
|
||||||
|
|
||||||
|
routes = append(routes, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Slice(routes, func(i, j int) bool {
|
sort.Slice(routes, func(i, j int) bool {
|
||||||
@@ -76,9 +92,13 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
|
|||||||
resolvedDomains := s.statusRecorder.GetResolvedDomainsStates()
|
resolvedDomains := s.statusRecorder.GetResolvedDomainsStates()
|
||||||
var pbRoutes []*proto.Network
|
var pbRoutes []*proto.Network
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
|
rangeStr := route.Network.String()
|
||||||
|
for _, extra := range route.extraNetworks {
|
||||||
|
rangeStr += ", " + extra.String()
|
||||||
|
}
|
||||||
pbRoute := &proto.Network{
|
pbRoute := &proto.Network{
|
||||||
ID: string(route.NetID),
|
ID: string(route.NetID),
|
||||||
Range: route.Network.String(),
|
Range: rangeStr,
|
||||||
Domains: route.Domains.ToSafeStringList(),
|
Domains: route.Domains.ToSafeStringList(),
|
||||||
ResolvedIPs: map[string]*proto.IPList{},
|
ResolvedIPs: map[string]*proto.IPList{},
|
||||||
Selected: route.Selected,
|
Selected: route.Selected,
|
||||||
@@ -137,7 +157,9 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
|
|||||||
routeSelector.SelectAllRoutes()
|
routeSelector.SelectAllRoutes()
|
||||||
} else {
|
} else {
|
||||||
routes := toNetIDs(req.GetNetworkIDs())
|
routes := toNetIDs(req.GetNetworkIDs())
|
||||||
netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID())
|
routesMap := routeManager.GetClientRoutesWithNetID()
|
||||||
|
routes = route.ExpandV6ExitPairs(routes, routesMap)
|
||||||
|
netIdRoutes := maps.Keys(routesMap)
|
||||||
if err := routeSelector.SelectRoutes(routes, req.GetAppend(), netIdRoutes); err != nil {
|
if err := routeSelector.SelectRoutes(routes, req.GetAppend(), netIdRoutes); err != nil {
|
||||||
return nil, fmt.Errorf("select routes: %w", err)
|
return nil, fmt.Errorf("select routes: %w", err)
|
||||||
}
|
}
|
||||||
@@ -183,7 +205,9 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe
|
|||||||
routeSelector.DeselectAllRoutes()
|
routeSelector.DeselectAllRoutes()
|
||||||
} else {
|
} else {
|
||||||
routes := toNetIDs(req.GetNetworkIDs())
|
routes := toNetIDs(req.GetNetworkIDs())
|
||||||
netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID())
|
routesMap := routeManager.GetClientRoutesWithNetID()
|
||||||
|
routes = route.ExpandV6ExitPairs(routes, routesMap)
|
||||||
|
netIdRoutes := maps.Keys(routesMap)
|
||||||
if err := routeSelector.DeselectRoutes(routes, netIdRoutes); err != nil {
|
if err := routeSelector.DeselectRoutes(routes, netIdRoutes); err != nil {
|
||||||
return nil, fmt.Errorf("deselect routes: %w", err)
|
return nil, fmt.Errorf("deselect routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -195,7 +195,7 @@ func getOverlappingNetworks(routes []*proto.Network) []*proto.Network {
|
|||||||
func getExitNodeNetworks(routes []*proto.Network) []*proto.Network {
|
func getExitNodeNetworks(routes []*proto.Network) []*proto.Network {
|
||||||
var filteredRoutes []*proto.Network
|
var filteredRoutes []*proto.Network
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
if route.Range == "0.0.0.0/0" || route.Range == "::/0" {
|
if strings.Contains(route.Range, "0.0.0.0/0") || route.Range == "::/0" {
|
||||||
filteredRoutes = append(filteredRoutes, route)
|
filteredRoutes = append(filteredRoutes, route)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"strconv"
|
||||||
"syscall/js"
|
"syscall/js"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -166,39 +167,58 @@ func createSSHMethod(client *netbird.Client) js.Func {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
var jwtToken string
|
jwtToken, ipVersion := parseSSHOptions(args)
|
||||||
if len(args) > 3 && !args[3].IsNull() && !args[3].IsUndefined() {
|
|
||||||
jwtToken = args[3].String()
|
|
||||||
}
|
|
||||||
|
|
||||||
return createPromise(func(resolve, reject js.Value) {
|
return createPromise(func(resolve, reject js.Value) {
|
||||||
sshClient := ssh.NewClient(client)
|
jsInterface, err := connectSSH(client, host, port, username, jwtToken, ipVersion)
|
||||||
|
if err != nil {
|
||||||
if err := sshClient.Connect(host, port, username, jwtToken); err != nil {
|
|
||||||
reject.Invoke(err.Error())
|
reject.Invoke(err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := sshClient.StartSession(80, 24); err != nil {
|
|
||||||
if closeErr := sshClient.Close(); closeErr != nil {
|
|
||||||
log.Errorf("Error closing SSH client: %v", closeErr)
|
|
||||||
}
|
|
||||||
reject.Invoke(err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
jsInterface := ssh.CreateJSInterface(sshClient)
|
|
||||||
resolve.Invoke(jsInterface)
|
resolve.Invoke(jsInterface)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func performPing(client *netbird.Client, hostname string) {
|
func parseSSHOptions(args []js.Value) (jwtToken string, ipVersion int) {
|
||||||
|
if len(args) > 3 && !args[3].IsNull() && !args[3].IsUndefined() {
|
||||||
|
jwtToken = args[3].String()
|
||||||
|
}
|
||||||
|
if len(args) > 4 {
|
||||||
|
ipVersion = jsIPVersion(args[4])
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func connectSSH(client *netbird.Client, host string, port int, username, jwtToken string, ipVersion int) (js.Value, error) {
|
||||||
|
sshClient := ssh.NewClient(client)
|
||||||
|
|
||||||
|
if err := sshClient.Connect(host, port, username, jwtToken, ipVersion); err != nil {
|
||||||
|
return js.Undefined(), err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sshClient.StartSession(80, 24); err != nil {
|
||||||
|
if closeErr := sshClient.Close(); closeErr != nil {
|
||||||
|
log.Errorf("Error closing SSH client: %v", closeErr)
|
||||||
|
}
|
||||||
|
return js.Undefined(), err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ssh.CreateJSInterface(sshClient), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func performPing(client *netbird.Client, hostname string, ipVersion int) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
// Default to ping4 to avoid dual-stack ICMP endpoint issues in wireguard-go netstack.
|
||||||
|
network := "ping4"
|
||||||
|
if ipVersion == 6 {
|
||||||
|
network = "ping6"
|
||||||
|
}
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
conn, err := client.Dial(ctx, "ping", hostname)
|
conn, err := client.Dial(ctx, network, hostname)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s failed: %v", hostname, err))
|
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s failed: %v", hostname, err))
|
||||||
return
|
return
|
||||||
@@ -225,27 +245,39 @@ func performPing(client *netbird.Client, hostname string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
latency := time.Since(start)
|
latency := time.Since(start)
|
||||||
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds()))
|
remote := conn.RemoteAddr().String()
|
||||||
|
msg := fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds())
|
||||||
|
if remote != hostname {
|
||||||
|
msg += fmt.Sprintf(" (via %s)", remote)
|
||||||
|
}
|
||||||
|
js.Global().Get("console").Call("log", msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func performPingTCP(client *netbird.Client, hostname string, port int) {
|
func performPingTCP(client *netbird.Client, hostname string, port, ipVersion int) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
network := ipVersionNetwork("tcp", ipVersion)
|
||||||
|
|
||||||
address := net.JoinHostPort(hostname, fmt.Sprintf("%d", port))
|
address := net.JoinHostPort(hostname, fmt.Sprintf("%d", port))
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
conn, err := client.Dial(ctx, "tcp", address)
|
conn, err := client.Dial(ctx, network, address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s failed: %v", address, err))
|
js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s failed: %v", address, err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
latency := time.Since(start)
|
latency := time.Since(start)
|
||||||
|
|
||||||
|
remote := conn.RemoteAddr().String()
|
||||||
if err := conn.Close(); err != nil {
|
if err := conn.Close(); err != nil {
|
||||||
log.Debugf("failed to close TCP connection: %v", err)
|
log.Debugf("failed to close TCP connection: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds()))
|
msg := fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds())
|
||||||
|
if remote != address {
|
||||||
|
msg += fmt.Sprintf(" (via %s)", remote)
|
||||||
|
}
|
||||||
|
js.Global().Get("console").Call("log", msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// createPingMethod creates the ping method
|
// createPingMethod creates the ping method
|
||||||
@@ -262,8 +294,12 @@ func createPingMethod(client *netbird.Client) js.Func {
|
|||||||
}
|
}
|
||||||
|
|
||||||
hostname := args[0].String()
|
hostname := args[0].String()
|
||||||
|
var ipVersion int
|
||||||
|
if len(args) > 1 {
|
||||||
|
ipVersion = jsIPVersion(args[1])
|
||||||
|
}
|
||||||
return createPromise(func(resolve, reject js.Value) {
|
return createPromise(func(resolve, reject js.Value) {
|
||||||
performPing(client, hostname)
|
performPing(client, hostname, ipVersion)
|
||||||
resolve.Invoke(js.Undefined())
|
resolve.Invoke(js.Undefined())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@@ -290,8 +326,12 @@ func createPingTCPMethod(client *netbird.Client) js.Func {
|
|||||||
|
|
||||||
hostname := args[0].String()
|
hostname := args[0].String()
|
||||||
port := args[1].Int()
|
port := args[1].Int()
|
||||||
|
var ipVersion int
|
||||||
|
if len(args) > 2 {
|
||||||
|
ipVersion = jsIPVersion(args[2])
|
||||||
|
}
|
||||||
return createPromise(func(resolve, reject js.Value) {
|
return createPromise(func(resolve, reject js.Value) {
|
||||||
performPingTCP(client, hostname, port)
|
performPingTCP(client, hostname, port, ipVersion)
|
||||||
resolve.Invoke(js.Undefined())
|
resolve.Invoke(js.Undefined())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@@ -464,6 +504,31 @@ func createSetLogLevelMethod(client *netbird.Client) js.Func {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ipVersionNetwork appends "4" or "6" to a base network string (e.g. "tcp" -> "tcp4").
|
||||||
|
func ipVersionNetwork(base string, ipVersion int) string {
|
||||||
|
switch ipVersion {
|
||||||
|
case 4:
|
||||||
|
return base + "4"
|
||||||
|
case 6:
|
||||||
|
return base + "6"
|
||||||
|
default:
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// jsIPVersion extracts an IP version (4 or 6) from a JS string or number.
|
||||||
|
func jsIPVersion(v js.Value) int {
|
||||||
|
switch v.Type() {
|
||||||
|
case js.TypeNumber:
|
||||||
|
return v.Int()
|
||||||
|
case js.TypeString:
|
||||||
|
n, _ := strconv.Atoi(v.String())
|
||||||
|
return n
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// createPromise is a helper to create JavaScript promises
|
// createPromise is a helper to create JavaScript promises
|
||||||
func createPromise(handler func(resolve, reject js.Value)) js.Value {
|
func createPromise(handler func(resolve, reject js.Value)) js.Value {
|
||||||
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
|
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
|
||||||
|
|||||||
@@ -46,8 +46,9 @@ func NewClient(nbClient *netbird.Client) *Client {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect establishes an SSH connection through NetBird network
|
// Connect establishes an SSH connection through NetBird network.
|
||||||
func (c *Client) Connect(host string, port int, username, jwtToken string) error {
|
// ipVersion may be 4, 6, or 0 for automatic selection.
|
||||||
|
func (c *Client) Connect(host string, port int, username, jwtToken string, ipVersion int) error {
|
||||||
addr := net.JoinHostPort(host, fmt.Sprintf("%d", port))
|
addr := net.JoinHostPort(host, fmt.Sprintf("%d", port))
|
||||||
logrus.Infof("SSH: Connecting to %s as %s", addr, username)
|
logrus.Infof("SSH: Connecting to %s as %s", addr, username)
|
||||||
|
|
||||||
@@ -63,10 +64,18 @@ func (c *Client) Connect(host string, port int, username, jwtToken string) error
|
|||||||
Timeout: sshDialTimeout,
|
Timeout: sshDialTimeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
network := "tcp"
|
||||||
|
switch ipVersion {
|
||||||
|
case 4:
|
||||||
|
network = "tcp4"
|
||||||
|
case 6:
|
||||||
|
network = "tcp6"
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
conn, err := c.nbClient.Dial(ctx, "tcp", addr)
|
conn, err := c.nbClient.Dial(ctx, network, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("dial %s: %w", addr, err)
|
return fmt.Errorf("dial %s: %w", addr, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
@@ -57,7 +58,11 @@ var debugSyncCmd = &cobra.Command{
|
|||||||
SilenceUsage: true,
|
SilenceUsage: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
var pingTimeout string
|
var (
|
||||||
|
pingTimeout time.Duration
|
||||||
|
pingIPv4 bool
|
||||||
|
pingIPv6 bool
|
||||||
|
)
|
||||||
|
|
||||||
var debugPingCmd = &cobra.Command{
|
var debugPingCmd = &cobra.Command{
|
||||||
Use: "ping <account-id> <host> [port]",
|
Use: "ping <account-id> <host> [port]",
|
||||||
@@ -108,7 +113,10 @@ func init() {
|
|||||||
debugStatusCmd.Flags().StringVar(&statusFilterByStatus, "filter-by-status", "", "Filter by status (idle|connecting|connected)")
|
debugStatusCmd.Flags().StringVar(&statusFilterByStatus, "filter-by-status", "", "Filter by status (idle|connecting|connected)")
|
||||||
debugStatusCmd.Flags().StringVar(&statusFilterByConnectionType, "filter-by-connection-type", "", "Filter by connection type (P2P|Relayed)")
|
debugStatusCmd.Flags().StringVar(&statusFilterByConnectionType, "filter-by-connection-type", "", "Filter by connection type (P2P|Relayed)")
|
||||||
|
|
||||||
debugPingCmd.Flags().StringVar(&pingTimeout, "timeout", "", "Ping timeout (e.g., 10s)")
|
debugPingCmd.Flags().DurationVar(&pingTimeout, "timeout", 0, "Ping timeout (e.g., 10s)")
|
||||||
|
debugPingCmd.Flags().BoolVarP(&pingIPv4, "ipv4", "4", false, "Force IPv4")
|
||||||
|
debugPingCmd.Flags().BoolVarP(&pingIPv6, "ipv6", "6", false, "Force IPv6")
|
||||||
|
debugPingCmd.MarkFlagsMutuallyExclusive("ipv4", "ipv6")
|
||||||
|
|
||||||
debugCmd.AddCommand(debugHealthCmd)
|
debugCmd.AddCommand(debugHealthCmd)
|
||||||
debugCmd.AddCommand(debugClientsCmd)
|
debugCmd.AddCommand(debugClientsCmd)
|
||||||
@@ -157,7 +165,14 @@ func runDebugPing(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
port = p
|
port = p
|
||||||
}
|
}
|
||||||
return getDebugClient(cmd).PingTCP(cmd.Context(), args[0], args[1], port, pingTimeout)
|
var ipVersion string
|
||||||
|
switch {
|
||||||
|
case pingIPv4:
|
||||||
|
ipVersion = "4"
|
||||||
|
case pingIPv6:
|
||||||
|
ipVersion = "6"
|
||||||
|
}
|
||||||
|
return getDebugClient(cmd).PingTCP(cmd.Context(), args[0], args[1], port, pingTimeout, ipVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
func runDebugLogLevel(cmd *cobra.Command, args []string) error {
|
func runDebugLogLevel(cmd *cobra.Command, args []string) error {
|
||||||
|
|||||||
@@ -6,10 +6,12 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// StatusFilters contains filter options for status queries.
|
// StatusFilters contains filter options for status queries.
|
||||||
@@ -230,12 +232,16 @@ func (c *Client) ClientSyncResponse(ctx context.Context, accountID string) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
// PingTCP performs a TCP ping through a client.
|
// PingTCP performs a TCP ping through a client.
|
||||||
func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, timeout string) error {
|
// ipVersion may be "4", "6", or "" for automatic.
|
||||||
|
func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, timeout time.Duration, ipVersion string) error {
|
||||||
params := url.Values{}
|
params := url.Values{}
|
||||||
params.Set("host", host)
|
params.Set("host", host)
|
||||||
params.Set("port", fmt.Sprintf("%d", port))
|
params.Set("port", fmt.Sprintf("%d", port))
|
||||||
if timeout != "" {
|
if timeout > 0 {
|
||||||
params.Set("timeout", timeout)
|
params.Set("timeout", timeout.String())
|
||||||
|
}
|
||||||
|
if ipVersion != "" {
|
||||||
|
params.Set("ip_version", ipVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
path := fmt.Sprintf("/debug/clients/%s/pingtcp?%s", url.PathEscape(accountID), params.Encode())
|
path := fmt.Sprintf("/debug/clients/%s/pingtcp?%s", url.PathEscape(accountID), params.Encode())
|
||||||
@@ -244,11 +250,17 @@ func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int,
|
|||||||
|
|
||||||
func (c *Client) printPingResult(data map[string]any) {
|
func (c *Client) printPingResult(data map[string]any) {
|
||||||
success, _ := data["success"].(bool)
|
success, _ := data["success"].(bool)
|
||||||
|
host := net.JoinHostPort(fmt.Sprint(data["host"]), fmt.Sprint(data["port"]))
|
||||||
if success {
|
if success {
|
||||||
_, _ = fmt.Fprintf(c.out, "Success: %v:%v\n", data["host"], data["port"])
|
remote, _ := data["remote"].(string)
|
||||||
|
if remote != "" && remote != host {
|
||||||
|
_, _ = fmt.Fprintf(c.out, "Success: %s (via %s)\n", host, remote)
|
||||||
|
} else {
|
||||||
|
_, _ = fmt.Fprintf(c.out, "Success: %s\n", host)
|
||||||
|
}
|
||||||
_, _ = fmt.Fprintf(c.out, "Latency: %v\n", data["latency"])
|
_, _ = fmt.Fprintf(c.out, "Latency: %v\n", data["latency"])
|
||||||
} else {
|
} else {
|
||||||
_, _ = fmt.Fprintf(c.out, "Failed: %v:%v\n", data["host"], data["port"])
|
_, _ = fmt.Fprintf(c.out, "Failed: %s\n", host)
|
||||||
c.printError(data)
|
c.printError(data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
"html/template"
|
||||||
"maps"
|
"maps"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -525,13 +526,18 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
network := "tcp"
|
||||||
|
if v := r.URL.Query().Get("ip_version"); v == "4" || v == "6" {
|
||||||
|
network += v
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
address := fmt.Sprintf("%s:%d", host, port)
|
address := net.JoinHostPort(host, strconv.Itoa(port))
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
conn, err := client.Dial(ctx, "tcp", address)
|
conn, err := client.Dial(ctx, network, address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.writeJSON(w, map[string]interface{}{
|
h.writeJSON(w, map[string]interface{}{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -541,18 +547,22 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
remote := conn.RemoteAddr().String()
|
||||||
if err := conn.Close(); err != nil {
|
if err := conn.Close(); err != nil {
|
||||||
h.logger.Debugf("close tcp ping connection: %v", err)
|
h.logger.Debugf("close tcp ping connection: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
latency := time.Since(start)
|
latency := time.Since(start)
|
||||||
h.writeJSON(w, map[string]interface{}{
|
resp := map[string]interface{}{
|
||||||
"success": true,
|
"success": true,
|
||||||
"host": host,
|
"host": host,
|
||||||
"port": port,
|
"port": port,
|
||||||
|
"remote": remote,
|
||||||
"latency_ms": latency.Milliseconds(),
|
"latency_ms": latency.Milliseconds(),
|
||||||
"latency": formatDuration(latency),
|
"latency": formatDuration(latency),
|
||||||
})
|
}
|
||||||
|
h.writeJSON(w, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||||
|
|||||||
@@ -20,6 +20,9 @@ const (
|
|||||||
MaxMetric = 9999
|
MaxMetric = 9999
|
||||||
// MaxNetIDChar Max Network Identifier
|
// MaxNetIDChar Max Network Identifier
|
||||||
MaxNetIDChar = 40
|
MaxNetIDChar = 40
|
||||||
|
|
||||||
|
// V6ExitSuffix is appended to a v4 exit node NetID to form its v6 counterpart.
|
||||||
|
V6ExitSuffix = "-v6"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -215,3 +218,61 @@ func ParseNetwork(networkString string) (NetworkType, netip.Prefix, error) {
|
|||||||
|
|
||||||
return IPv4Network, masked, nil
|
return IPv4Network, masked, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
v4Default = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||||
|
v6Default = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||||
|
)
|
||||||
|
|
||||||
|
// IsV4DefaultRoute reports whether p is the IPv4 default route (0.0.0.0/0).
|
||||||
|
func IsV4DefaultRoute(p netip.Prefix) bool { return p == v4Default }
|
||||||
|
|
||||||
|
// IsV6DefaultRoute reports whether p is the IPv6 default route (::/0).
|
||||||
|
func IsV6DefaultRoute(p netip.Prefix) bool { return p == v6Default }
|
||||||
|
|
||||||
|
// ExpandV6ExitPairs appends the paired "-v6" exit node NetID for any v4 exit
|
||||||
|
// node (0.0.0.0/0) in ids that has a matching v6 counterpart (::/0) in routesMap.
|
||||||
|
// It modifies and returns the input slice.
|
||||||
|
func ExpandV6ExitPairs(ids []NetID, routesMap map[NetID][]*Route) []NetID {
|
||||||
|
for _, id := range ids {
|
||||||
|
rt, ok := routesMap[id]
|
||||||
|
if !ok || len(rt) == 0 || !IsV4DefaultRoute(rt[0].Network) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
v6ID := NetID(string(id) + V6ExitSuffix)
|
||||||
|
if v6Rt, ok := routesMap[v6ID]; ok && len(v6Rt) > 0 && IsV6DefaultRoute(v6Rt[0].Network) {
|
||||||
|
if !slices.Contains(ids, v6ID) {
|
||||||
|
ids = append(ids, v6ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
|
// V6ExitMergeSet scans routesMap and returns the set of v6 exit node NetIDs
|
||||||
|
// that should be hidden from the UI because they are paired with a v4 exit node.
|
||||||
|
// A v6 ID is paired when it has suffix "-v6", its route is ::/0, and the base
|
||||||
|
// name (without "-v6") exists with route 0.0.0.0/0.
|
||||||
|
func V6ExitMergeSet(routesMap map[NetID][]*Route) map[NetID]struct{} {
|
||||||
|
merged := make(map[NetID]struct{})
|
||||||
|
for id, rt := range routesMap {
|
||||||
|
if len(rt) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := string(id)
|
||||||
|
if !IsV6DefaultRoute(rt[0].Network) || !strings.HasSuffix(name, V6ExitSuffix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
baseName := NetID(strings.TrimSuffix(name, V6ExitSuffix))
|
||||||
|
if baseRt, ok := routesMap[baseName]; ok && len(baseRt) > 0 && IsV4DefaultRoute(baseRt[0].Network) {
|
||||||
|
merged[id] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return merged
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasV6ExitPair reports whether id has a paired v6 exit node in the merge set.
|
||||||
|
func HasV6ExitPair(id NetID, v6Merged map[NetID]struct{}) bool {
|
||||||
|
_, ok := v6Merged[NetID(string(id)+"-v6")]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|||||||
108
route/route_test.go
Normal file
108
route/route_test.go
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
package route
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExpandV6ExitPairs(t *testing.T) {
|
||||||
|
v4ExitRoute := &Route{Network: netip.MustParsePrefix("0.0.0.0/0")}
|
||||||
|
v6ExitRoute := &Route{Network: netip.MustParsePrefix("::/0")}
|
||||||
|
regularRoute := &Route{Network: netip.MustParsePrefix("10.0.0.0/8")}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ids []NetID
|
||||||
|
routesMap map[NetID][]*Route
|
||||||
|
expected []NetID
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "v4 exit node with matching v6 pair",
|
||||||
|
ids: []NetID{"exit-node"},
|
||||||
|
routesMap: map[NetID][]*Route{
|
||||||
|
"exit-node": {v4ExitRoute},
|
||||||
|
"exit-node-v6": {v6ExitRoute},
|
||||||
|
},
|
||||||
|
expected: []NetID{"exit-node", "exit-node-v6"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "v4 exit node without v6 pair",
|
||||||
|
ids: []NetID{"exit-node"},
|
||||||
|
routesMap: map[NetID][]*Route{
|
||||||
|
"exit-node": {v4ExitRoute},
|
||||||
|
},
|
||||||
|
expected: []NetID{"exit-node"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "regular route is not expanded",
|
||||||
|
ids: []NetID{"office"},
|
||||||
|
routesMap: map[NetID][]*Route{
|
||||||
|
"office": {regularRoute},
|
||||||
|
"office-v6": {v6ExitRoute},
|
||||||
|
},
|
||||||
|
expected: []NetID{"office"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "v6 already included is not duplicated",
|
||||||
|
ids: []NetID{"exit-node", "exit-node-v6"},
|
||||||
|
routesMap: map[NetID][]*Route{
|
||||||
|
"exit-node": {v4ExitRoute},
|
||||||
|
"exit-node-v6": {v6ExitRoute},
|
||||||
|
},
|
||||||
|
expected: []NetID{"exit-node", "exit-node-v6"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple exit nodes expanded independently",
|
||||||
|
ids: []NetID{"exit-a", "exit-b"},
|
||||||
|
routesMap: map[NetID][]*Route{
|
||||||
|
"exit-a": {v4ExitRoute},
|
||||||
|
"exit-a-v6": {v6ExitRoute},
|
||||||
|
"exit-b": {v4ExitRoute},
|
||||||
|
"exit-b-v6": {v6ExitRoute},
|
||||||
|
},
|
||||||
|
expected: []NetID{"exit-a", "exit-b", "exit-a-v6", "exit-b-v6"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "v6 suffix but not exit node network",
|
||||||
|
ids: []NetID{"office"},
|
||||||
|
routesMap: map[NetID][]*Route{
|
||||||
|
"office": {regularRoute},
|
||||||
|
"office-v6": {regularRoute},
|
||||||
|
},
|
||||||
|
expected: []NetID{"office"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user-chosen name for exit node with v6 pair",
|
||||||
|
ids: []NetID{"my-exit"},
|
||||||
|
routesMap: map[NetID][]*Route{
|
||||||
|
"my-exit": {v4ExitRoute},
|
||||||
|
"my-exit-v6": {v6ExitRoute},
|
||||||
|
},
|
||||||
|
expected: []NetID{"my-exit", "my-exit-v6"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "real-world management-generated IDs",
|
||||||
|
ids: []NetID{"0.0.0.0/0"},
|
||||||
|
routesMap: map[NetID][]*Route{
|
||||||
|
"0.0.0.0/0": {v4ExitRoute},
|
||||||
|
"0.0.0.0/0-v6": {v6ExitRoute},
|
||||||
|
},
|
||||||
|
expected: []NetID{"0.0.0.0/0", "0.0.0.0/0-v6"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty input",
|
||||||
|
ids: []NetID{},
|
||||||
|
routesMap: map[NetID][]*Route{},
|
||||||
|
expected: []NetID{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := ExpandV6ExitPairs(tt.ids, tt.routesMap)
|
||||||
|
assert.ElementsMatch(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -46,7 +46,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
|||||||
InitialPacketSize: nbRelay.QUICInitialPacketSize,
|
InitialPacketSize: nbRelay.QUICInitialPacketSize,
|
||||||
}
|
}
|
||||||
|
|
||||||
udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{Port: 0})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to listen on UDP: %s", err)
|
log.Errorf("failed to listen on UDP: %s", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
Reference in New Issue
Block a user