Compare commits

..

7 Commits

261 changed files with 4589 additions and 11432 deletions

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA
skip: go.mod,go.sum,**/proxy/web/**
golangci:
strategy:

1
.gitignore vendored
View File

@@ -33,4 +33,3 @@ infrastructure_files/setup-*.env
vendor/
/netbird
client/netbird-electron/
management/server/types/testdata/

View File

@@ -92,6 +92,9 @@ linters:
- linters:
- unused
path: client/firewall/iptables/rule\.go
- linters:
- unused
path: client/internal/dns/dnsfw/(types|syscall|zsyscall)_windows.*\.go
- linters:
- gosec
- mirror

View File

@@ -301,11 +301,10 @@ func (c *Client) PeersList() *PeerInfoArray {
peerInfos := make([]PeerInfo, len(fullStatus.Peers))
for n, p := range fullStatus.Peers {
pi := PeerInfo{
IP: p.IP,
IPv6: p.IPv6,
FQDN: p.FQDN,
ConnStatus: int(p.ConnStatus),
Routes: PeerRoutes{routes: maps.Keys(p.GetRoutes())},
p.IP,
p.FQDN,
int(p.ConnStatus),
PeerRoutes{routes: maps.Keys(p.GetRoutes())},
}
peerInfos[n] = pi
}
@@ -337,84 +336,43 @@ func (c *Client) Networks() *NetworkArray {
return nil
}
routesMap := routeManager.GetClientRoutesWithNetID()
v6Merged := route.V6ExitMergeSet(routesMap)
resolvedDomains := c.recorder.GetResolvedDomainsStates()
networkArray := &NetworkArray{
items: make([]Network, 0),
}
for id, routes := range routesMap {
resolvedDomains := c.recorder.GetResolvedDomainsStates()
for id, routes := range routeManager.GetClientRoutesWithNetID() {
if len(routes) == 0 {
continue
}
if _, skip := v6Merged[id]; skip {
continue
r := routes[0]
domains := c.getNetworkDomainsFromRoute(r, resolvedDomains)
netStr := r.Network.String()
if r.IsDynamic() {
netStr = r.Domains.SafeString()
}
network := c.buildNetwork(id, routes, routeSelector.IsSelected(id), resolvedDomains, v6Merged)
if network == nil {
routePeer, err := c.recorder.GetPeer(routes[0].Peer)
if err != nil {
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
continue
}
networkArray.Add(*network)
network := Network{
Name: string(id),
Network: netStr,
Peer: routePeer.FQDN,
Status: routePeer.ConnStatus.String(),
IsSelected: routeSelector.IsSelected(id),
Domains: domains,
}
networkArray.Add(network)
}
return networkArray
}
func (c *Client) buildNetwork(id route.NetID, routes []*route.Route, selected bool, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo, v6Merged map[route.NetID]struct{}) *Network {
r := routes[0]
netStr := r.Network.String()
if r.IsDynamic() {
netStr = r.Domains.SafeString()
}
routePeer, err := c.findBestRoutePeer(routes)
if err != nil {
log.Errorf("could not get peer info for route %s: %v", id, err)
return nil
}
network := &Network{
Name: string(id),
Network: netStr,
Peer: routePeer.FQDN,
Status: routePeer.ConnStatus.String(),
IsSelected: selected,
Domains: c.getNetworkDomainsFromRoute(r, resolvedDomains),
}
if route.IsV4DefaultRoute(r.Network) && route.HasV6ExitPair(id, v6Merged) {
network.Network = "0.0.0.0/0, ::/0"
}
return network
}
// findBestRoutePeer returns the peer actively routing traffic for the given
// HA route group. Falls back to the first connected peer, then the first peer.
func (c *Client) findBestRoutePeer(routes []*route.Route) (peer.State, error) {
netStr := routes[0].Network.String()
fullStatus := c.recorder.GetFullStatus()
for _, p := range fullStatus.Peers {
if _, ok := p.GetRoutes()[netStr]; ok {
return p, nil
}
}
for _, r := range routes {
p, err := c.recorder.GetPeer(r.Peer)
if err != nil {
continue
}
if p.ConnStatus == peer.StatusConnected {
return p, nil
}
}
return c.recorder.GetPeer(routes[0].Peer)
}
// OnUpdatedHostDNS update the DNS servers addresses for root zones
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
dnsServer, err := dns.GetServerDns()

View File

@@ -14,7 +14,6 @@ const (
// PeerInfo describe information about the peers. It designed for the UI usage
type PeerInfo struct {
IP string
IPv6 string
FQDN string
ConnStatus int
Routes PeerRoutes

View File

@@ -307,24 +307,6 @@ func (p *Preferences) SetBlockInbound(block bool) {
p.configInput.BlockInbound = &block
}
// GetDisableIPv6 reads disable IPv6 setting from config file
func (p *Preferences) GetDisableIPv6() (bool, error) {
if p.configInput.DisableIPv6 != nil {
return *p.configInput.DisableIPv6, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableIPv6, err
}
// SetDisableIPv6 stores the given value and waits for commit
func (p *Preferences) SetDisableIPv6(disable bool) {
p.configInput.DisableIPv6 = &disable
}
// Commit writes out the changes to the config file
func (p *Preferences) Commit() error {
_, err := profilemanager.UpdateOrCreateConfig(p.configInput)

View File

@@ -18,12 +18,9 @@ func executeRouteToggle(id string, manager routemanager.Manager,
netID := route.NetID(id)
routes := []route.NetID{netID}
routesMap := manager.GetClientRoutesWithNetID()
routes = route.ExpandV6ExitPairs(routes, routesMap)
log.Debugf("%s with id: %s", operationName, id)
log.Debugf("%s with ids: %v", operationName, routes)
if err := routeOperation(routes, maps.Keys(routesMap)); err != nil {
if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil {
log.Debugf("error when %s: %s", operationName, err)
return fmt.Errorf("error %s: %w", operationName, err)
}

View File

@@ -9,7 +9,6 @@ import (
"net/url"
"regexp"
"slices"
"strconv"
"strings"
)
@@ -27,9 +26,8 @@ type Anonymizer struct {
}
func DefaultAddresses() (netip.Addr, netip.Addr) {
// 198.51.100.0 (RFC 5737 TEST-NET-2), 2001:db8:ffff:: (RFC 3849 documentation, last /48)
// The old start 100:: (discard, RFC 6666) is now used for fake IPs on Android.
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.MustParseAddr("2001:db8:ffff::")
// 198.51.100.0, 100::
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
}
func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
@@ -50,7 +48,7 @@ func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() ||
ip.IsInterfaceLocalMulticast() ||
(ip.Is4() && ip.IsPrivate()) ||
ip.IsPrivate() ||
ip.IsUnspecified() ||
ip.IsMulticast() ||
isWellKnown(ip) ||
@@ -98,11 +96,6 @@ func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
}
func (a *Anonymizer) AnonymizeIPString(ip string) string {
// Handle CIDR notation (e.g. "2001:db8::/32")
if prefix, err := netip.ParsePrefix(ip); err == nil {
return a.AnonymizeIP(prefix.Addr()).String() + "/" + strconv.Itoa(prefix.Bits())
}
addr, err := netip.ParseAddr(ip)
if err != nil {
return ip
@@ -157,7 +150,7 @@ func (a *Anonymizer) AnonymizeURI(uri string) string {
if u.Opaque != "" {
host, port, err := net.SplitHostPort(u.Opaque)
if err == nil {
anonymizedHost = net.JoinHostPort(a.AnonymizeDomain(host), port)
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
} else {
anonymizedHost = a.AnonymizeDomain(u.Opaque)
}
@@ -165,7 +158,7 @@ func (a *Anonymizer) AnonymizeURI(uri string) string {
} else if u.Host != "" {
host, port, err := net.SplitHostPort(u.Host)
if err == nil {
anonymizedHost = net.JoinHostPort(a.AnonymizeDomain(host), port)
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
} else {
anonymizedHost = a.AnonymizeDomain(u.Host)
}

View File

@@ -13,7 +13,7 @@ import (
func TestAnonymizeIP(t *testing.T) {
startIPv4 := netip.MustParseAddr("198.51.100.0")
startIPv6 := netip.MustParseAddr("2001:db8:ffff::")
startIPv6 := netip.MustParseAddr("100::")
anonymizer := anonymize.NewAnonymizer(startIPv4, startIPv6)
tests := []struct {
@@ -26,9 +26,9 @@ func TestAnonymizeIP(t *testing.T) {
{"Second Public IPv4", "4.3.2.1", "198.51.100.1"},
{"Repeated IPv4", "1.2.3.4", "198.51.100.0"},
{"Private IPv4", "192.168.1.1", "192.168.1.1"},
{"First Public IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"},
{"Second Public IPv6", "a::b", "2001:db8:ffff::1"},
{"Repeated IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"},
{"First Public IPv6", "2607:f8b0:4005:805::200e", "100::"},
{"Second Public IPv6", "a::b", "100::1"},
{"Repeated IPv6", "2607:f8b0:4005:805::200e", "100::"},
{"Private IPv6", "fe80::1", "fe80::1"},
{"In Range IPv4", "198.51.100.2", "198.51.100.2"},
}
@@ -274,27 +274,17 @@ func TestAnonymizeString_IPAddresses(t *testing.T) {
{
name: "IPv6 Address",
input: "Access attempted from 2001:db8::ff00:42",
expect: "Access attempted from 2001:db8:ffff::",
expect: "Access attempted from 100::",
},
{
name: "IPv6 Address with Port",
input: "Access attempted from [2001:db8::ff00:42]:8080",
expect: "Access attempted from [2001:db8:ffff::]:8080",
expect: "Access attempted from [100::]:8080",
},
{
name: "Both IPv4 and IPv6",
input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43",
expect: "IPv4: 198.51.100.1 and IPv6: 2001:db8:ffff::1",
},
{
name: "STUN URI with IPv6",
input: "Connecting to stun:[2001:db8::ff00:42]:3478",
expect: "Connecting to stun:[2001:db8:ffff::]:3478",
},
{
name: "HTTPS URI with IPv6",
input: "Visit https://[2001:db8::ff00:42]:443/path",
expect: "Visit https://[2001:db8:ffff::]:443/path",
expect: "IPv4: 198.51.100.1 and IPv6: 100::1",
},
}

View File

@@ -523,7 +523,7 @@ func parseHostnameAndCommand(args []string) error {
}
func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
target := net.JoinHostPort(strings.Trim(addr, "[]"), strconv.Itoa(port))
target := fmt.Sprintf("%s:%d", addr, port)
c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{
KnownHostsFile: knownHostsFile,
IdentityFile: identityFile,
@@ -787,10 +787,10 @@ func isUnixSocket(path string) bool {
return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./")
}
// normalizeLocalHost converts "*" to "" for binding to all interfaces (dual-stack).
// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces.
func normalizeLocalHost(host string) string {
if host == "*" {
return ""
return "0.0.0.0"
}
return host
}

View File

@@ -527,10 +527,10 @@ func TestParsePortForward(t *testing.T) {
{
name: "wildcard bind all interfaces",
spec: "*:8080:localhost:80",
expectedLocal: ":8080",
expectedLocal: "0.0.0.0:8080",
expectedRemote: "localhost:80",
expectError: false,
description: "Wildcard * should bind to all interfaces (dual-stack)",
description: "Wildcard * should bind to all interfaces (0.0.0.0)",
},
{
name: "wildcard for port only",

View File

@@ -20,7 +20,6 @@ import (
var (
detailFlag bool
ipv4Flag bool
ipv6Flag bool
jsonFlag bool
yamlFlag bool
ipsFilter []string
@@ -46,9 +45,8 @@ func init() {
statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format")
statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format")
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
statusCmd.PersistentFlags().BoolVar(&ipv6Flag, "ipv6", false, "display only NetBird IPv6 of this peer")
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4", "ipv6")
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1")
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
@@ -103,14 +101,6 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil
}
if ipv6Flag {
ipv6 := resp.GetFullStatus().GetLocalPeerState().GetIpv6()
if ipv6 != "" {
cmd.Print(parseInterfaceIP(ipv6))
}
return nil
}
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {

View File

@@ -8,7 +8,6 @@ const (
disableFirewallFlag = "disable-firewall"
blockLANAccessFlag = "block-lan-access"
blockInboundFlag = "block-inbound"
disableIPv6Flag = "disable-ipv6"
)
var (
@@ -18,7 +17,6 @@ var (
disableFirewall bool
blockLANAccess bool
blockInbound bool
disableIPv6 bool
)
func init() {
@@ -41,7 +39,4 @@ func init() {
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
"This overrides any policies received from the management service.")
upCmd.PersistentFlags().BoolVar(&disableIPv6, disableIPv6Flag, false,
"Disable IPv6 overlay. If enabled, the client won't request or use an IPv6 overlay address.")
}

View File

@@ -435,10 +435,6 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
req.BlockInbound = &blockInbound
}
if cmd.Flag(disableIPv6Flag).Changed {
req.DisableIpv6 = &disableIPv6
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
req.LazyConnectionEnabled = &lazyConnEnabled
}
@@ -556,10 +552,6 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
ic.BlockInbound = &blockInbound
}
if cmd.Flag(disableIPv6Flag).Changed {
ic.DisableIPv6 = &disableIPv6
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
ic.LazyConnectionEnabled = &lazyConnEnabled
}
@@ -674,10 +666,6 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.BlockInbound = &blockInbound
}
if cmd.Flag(disableIPv6Flag).Changed {
loginRequest.DisableIpv6 = &disableIPv6
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
loginRequest.LazyConnectionEnabled = &lazyConnEnabled
}

View File

@@ -80,8 +80,6 @@ type Options struct {
StatePath string
// DisableClientRoutes disables the client routes
DisableClientRoutes bool
// DisableIPv6 disables IPv6 overlay addressing
DisableIPv6 bool
// BlockInbound blocks all inbound connections from peers
BlockInbound bool
// WireguardPort is the port for the tunnel interface. Use 0 for a random port.
@@ -173,7 +171,6 @@ func New(opts Options) (*Client, error) {
PreSharedKey: &opts.PreSharedKey,
DisableServerRoutes: &t,
DisableClientRoutes: &opts.DisableClientRoutes,
DisableIPv6: &opts.DisableIPv6,
BlockInbound: &opts.BlockInbound,
WireguardPort: opts.WireguardPort,
MTU: opts.MTU,

View File

@@ -40,7 +40,6 @@ type aclManager struct {
entries aclEntries
optionalEntries map[string][]entry
ipsetStore *ipsetStore
v6 bool
stateManager *statemanager.Manager
}
@@ -52,7 +51,6 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*acl
entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(),
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
}, nil
}
@@ -87,11 +85,7 @@ func (m *aclManager) AddPeerFiltering(
chain := chainNameInputRules
ipsetName = transformIPsetName(ipsetName, sPort, dPort, action)
if m.v6 && ipsetName != "" {
ipsetName += "-v6"
}
proto := protoForFamily(protocol, m.v6)
specs := filterRuleSpecs(ip, proto, sPort, dPort, action, ipsetName)
specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName)
mangleSpecs := slices.Clone(specs)
mangleSpecs = append(mangleSpecs,
@@ -115,7 +109,6 @@ func (m *aclManager) AddPeerFiltering(
ip: ip.String(),
chain: chain,
specs: specs,
v6: m.v6,
}}, nil
}
@@ -168,7 +161,6 @@ func (m *aclManager) AddPeerFiltering(
ipsetName: ipsetName,
ip: ip.String(),
chain: chain,
v6: m.v6,
}
m.updateState()
@@ -421,13 +413,8 @@ func (m *aclManager) updateState() {
currentState.Lock()
defer currentState.Unlock()
if m.v6 {
currentState.ACLEntries6 = m.entries
currentState.ACLIPsetStore6 = m.ipsetStore
} else {
currentState.ACLEntries = m.entries
currentState.ACLIPsetStore = m.ipsetStore
}
currentState.ACLEntries = m.entries
currentState.ACLIPsetStore = m.ipsetStore
if err := m.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
@@ -435,22 +422,13 @@ func (m *aclManager) updateState() {
}
// filterRuleSpecs returns the specs of a filtering rule
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
if v6 && protocol == firewall.ProtocolICMP {
return "ipv6-icmp"
}
return string(protocol)
}
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
// don't use IP matching if IP is 0.0.0.0
matchByIP := !ip.IsUnspecified()
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--match-set", ipsetName, "src")
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
} else {
specs = append(specs, "-s", ip.String())
}
@@ -496,9 +474,6 @@ func (m *aclManager) createIPSet(name string) error {
opts := ipset.CreateOptions{
Replace: true,
}
if m.v6 {
opts.Family = ipset.FamilyIPV6
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err)

View File

@@ -18,10 +18,6 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
type resetter interface {
Reset() error
}
// Manager of iptables firewall
type Manager struct {
mutex sync.Mutex
@@ -32,11 +28,6 @@ type Manager struct {
aclMgr *aclManager
router *router
rawSupported bool
// IPv6 counterparts, nil when no v6 overlay
ipv6Client *iptables.IPTables
aclMgr6 *aclManager
router6 *router
}
// iFaceMapper defines subset methods of interface required for manager
@@ -67,43 +58,9 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
return nil, fmt.Errorf("create acl manager: %w", err)
}
if wgIface.Address().HasIPv6() {
if err := m.createIPv6Components(wgIface, mtu); err != nil {
return nil, fmt.Errorf("create IPv6 firewall: %w", err)
}
}
return m, nil
}
func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error {
ip6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
return fmt.Errorf("init ip6tables: %w", err)
}
m.ipv6Client = ip6Client
m.router6, err = newRouter(ip6Client, wgIface, mtu)
if err != nil {
return fmt.Errorf("create v6 router: %w", err)
}
// Share the same IP forwarding state with the v4 router, since
// EnableIPForwarding controls both v4 and v6 sysctls.
m.router6.ipFwdState = m.router.ipFwdState
m.aclMgr6, err = newAclManager(ip6Client, wgIface)
if err != nil {
return fmt.Errorf("create v6 acl manager: %w", err)
}
return nil
}
func (m *Manager) hasIPv6() bool {
return m.ipv6Client != nil
}
func (m *Manager) Init(stateManager *statemanager.Manager) error {
state := &ShutdownState{
InterfaceState: &InterfaceState{
@@ -117,8 +74,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
log.Errorf("failed to update state: %v", err)
}
if err := m.initChains(stateManager); err != nil {
return err
if err := m.router.init(stateManager); err != nil {
return fmt.Errorf("router init: %w", err)
}
if err := m.aclMgr.init(stateManager); err != nil {
// TODO: cleanup router
return fmt.Errorf("acl manager init: %w", err)
}
if err := m.initNoTrackChain(); err != nil {
@@ -141,41 +103,6 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
return nil
}
// initChains initializes router and ACL chains for both address families,
// rolling back on failure.
func (m *Manager) initChains(stateManager *statemanager.Manager) error {
type initStep struct {
name string
init func(*statemanager.Manager) error
mgr resetter
}
steps := []initStep{
{"router", m.router.init, m.router},
{"acl manager", m.aclMgr.init, m.aclMgr},
}
if m.hasIPv6() {
steps = append(steps,
initStep{"v6 router", m.router6.init, m.router6},
initStep{"v6 acl manager", m.aclMgr6.init, m.aclMgr6},
)
}
var initialized []initStep
for _, s := range steps {
if err := s.init(stateManager); err != nil {
for i := len(initialized) - 1; i >= 0; i-- {
if rerr := initialized[i].mgr.Reset(); rerr != nil {
log.Warnf("rollback %s: %v", initialized[i].name, rerr)
}
}
return fmt.Errorf("%s init: %w", s.name, err)
}
initialized = append(initialized, s)
}
return nil
}
// AddPeerFiltering adds a rule to the firewall
//
// Comment will be ignored because some system this feature is not supported
@@ -191,13 +118,7 @@ func (m *Manager) AddPeerFiltering(
m.mutex.Lock()
defer m.mutex.Unlock()
if ip.To4() != nil {
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
if !m.hasIPv6() {
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
}
return m.aclMgr6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
func (m *Manager) AddRouteFiltering(
@@ -211,48 +132,25 @@ func (m *Manager) AddRouteFiltering(
m.mutex.Lock()
defer m.mutex.Unlock()
if isIPv6RouteRule(sources, destination) {
if !m.hasIPv6() {
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
}
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
if destination.IsPrefix() {
return destination.Prefix.Addr().Is6()
}
return len(sources) > 0 && sources[0].Addr().Is6()
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && isIPv6IptRule(rule) {
return m.aclMgr6.DeletePeerRule(rule)
}
return m.aclMgr.DeletePeerRule(rule)
}
func isIPv6IptRule(rule firewall.Rule) bool {
r, ok := rule.(*Rule)
return ok && r.v6
}
// DeleteRouteRule deletes a routing rule.
// Route rules are keyed by content hash. Check v4 first, try v6 if not found.
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && !m.router.hasRule(rule.ID()) {
return m.router6.DeleteRouteRule(rule)
}
return m.router.DeleteRouteRule(rule)
}
@@ -268,65 +166,18 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
if !m.hasIPv6() {
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddNatRule(pair)
}
if err := m.router.AddNatRule(pair); err != nil {
return err
}
// Dynamic routes need NAT in both tables since resolved IPs can be
// either v4 or v6. This covers both DomainSet (modern) and the legacy
// wildcard 0.0.0.0/0 destination where the client resolves DNS.
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.router6.AddNatRule(v6Pair); err != nil {
return fmt.Errorf("add v6 NAT rule: %w", err)
}
}
return nil
return m.router.AddNatRule(pair)
}
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
if !m.hasIPv6() {
return nil
}
return m.router6.RemoveNatRule(pair)
}
var merr *multierror.Error
if err := m.router.RemoveNatRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
}
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
return m.router.RemoveNatRule(pair)
}
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
return err
}
if m.hasIPv6() {
return firewall.SetLegacyManagement(m.router6, isLegacy)
}
return nil
return firewall.SetLegacyManagement(m.router, isLegacy)
}
// Reset firewall to the default state
@@ -340,15 +191,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err))
}
if m.hasIPv6() {
if err := m.aclMgr6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 acl manager: %w", err))
}
if err := m.router6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %w", err))
}
}
if err := m.aclMgr.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
}
@@ -376,21 +218,24 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
// This is called when USPFilter wraps the native firewall, adding blanket accept
// rules so that packet filtering is handled in userspace instead of by netfilter.
func (m *Manager) AllowNetbird() error {
var merr *multierror.Error
if _, err := m.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
merr = multierror.Append(merr, fmt.Errorf("allow netbird v4 interface traffic: %w", err))
}
if m.hasIPv6() {
if _, err := m.AddPeerFiltering(nil, net.IPv6zero, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
merr = multierror.Append(merr, fmt.Errorf("allow netbird v6 interface traffic: %w", err))
}
_, err := m.AddPeerFiltering(
nil,
net.IP{0, 0, 0, 0},
firewall.ProtocolALL,
nil,
nil,
firewall.ActionAccept,
"",
)
if err != nil {
return fmt.Errorf("allow netbird interface traffic: %w", err)
}
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
return nberrors.FormatErrorOrNil(merr)
return nil
}
// Flush doesn't need to be implemented for this manager
@@ -420,12 +265,6 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
m.mutex.Lock()
defer m.mutex.Unlock()
if rule.TranslatedAddress.Is6() {
if !m.hasIPv6() {
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddDNATRule(rule)
}
return m.router.AddDNATRule(rule)
}
@@ -434,9 +273,6 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && !m.router.hasRule(rule.ID()+dnatSuffix) {
return m.router6.DeleteDNATRule(rule)
}
return m.router.DeleteDNATRule(rule)
}
@@ -445,82 +281,39 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
m.mutex.Lock()
defer m.mutex.Unlock()
var v4Prefixes, v6Prefixes []netip.Prefix
for _, p := range prefixes {
if p.Addr().Is6() {
v6Prefixes = append(v6Prefixes, p)
} else {
v4Prefixes = append(v4Prefixes, p)
}
}
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
return err
}
if m.hasIPv6() && len(v6Prefixes) > 0 {
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
return fmt.Errorf("update v6 set: %w", err)
}
}
return nil
return m.router.UpdateSet(set, prefixes)
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if localAddr.Is6() {
if !m.hasIPv6() {
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if localAddr.Is6() {
if !m.hasIPv6() {
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if localAddr.Is6() {
if !m.hasIPv6() {
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if localAddr.Is6() {
if !m.hasIPv6() {
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
const (

View File

@@ -54,10 +54,8 @@ const (
snatSuffix = "_snat"
fwdSuffix = "_fwd"
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
ipv4TCPHeaderSize = 40
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
ipv6TCPHeaderSize = 60
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
)
type ruleInfo struct {
@@ -88,7 +86,6 @@ type router struct {
wgIface iFaceMapper
legacyManagement bool
mtu uint16
v6 bool
stateManager *statemanager.Manager
ipFwdState *ipfwdstate.IPForwardingState
@@ -100,7 +97,6 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1
rules: make(map[string][]string),
wgIface: wgIface,
mtu: mtu,
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
ipFwdState: ipfwdstate.NewIPForwardingState(),
}
@@ -190,11 +186,6 @@ func (r *router) AddRouteFiltering(
return ruleKey, nil
}
func (r *router) hasRule(id string) bool {
_, ok := r.rules[id]
return ok
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
ruleKey := rule.ID()
@@ -401,13 +392,9 @@ func (r *router) cleanUpDefaultForwardRules() error {
// Remove jump rules from built-in chains before deleting custom chains,
// otherwise the chain deletion fails with "device or resource busy".
if ok, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput); err != nil {
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
} else if ok {
jumpRule := []string{"-j", chainNATOutput}
if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil {
log.Debugf("clean OUTPUT jump rule: %v", err)
}
jumpRule := []string{"-j", chainNATOutput}
if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil {
log.Debugf("clean OUTPUT jump rule: %v", err)
}
for _, chainInfo := range []struct {
@@ -447,12 +434,6 @@ func (r *router) createContainers() error {
{chainRTRDR, tableNat},
{chainRTMSSCLAMP, tableMangle},
} {
// Fallback: clear chains that survived an unclean shutdown.
if ok, _ := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain); ok {
if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
log.Warnf("clear stale chain %s in %s: %v", chainInfo.chain, chainInfo.table, err)
}
}
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
}
@@ -559,12 +540,9 @@ func (r *router) addPostroutingRules() error {
}
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
// TODO: Add IPv6 support
func (r *router) addMSSClampingRules() error {
overhead := uint16(ipv4TCPHeaderSize)
if r.v6 {
overhead = ipv6TCPHeaderSize
}
mss := r.mtu - overhead
mss := r.mtu - ipTCPHeaderMinSize
// Add jump rule from FORWARD chain in mangle table to our custom chain
jumpRule := []string{
@@ -749,13 +727,8 @@ func (r *router) updateState() {
currentState.Lock()
defer currentState.Unlock()
if r.v6 {
currentState.RouteRules6 = r.rules
currentState.RouteIPsetCounter6 = r.ipsetCounter
} else {
currentState.RouteRules = r.rules
currentState.RouteIPsetCounter = r.ipsetCounter
}
currentState.RouteRules = r.rules
currentState.RouteIPsetCounter = r.ipsetCounter
if err := r.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
@@ -883,7 +856,7 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
}
if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists {
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDOUT, fwdRule...); err != nil {
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
}
delete(r.rules, ruleKey+fwdSuffix)
@@ -910,7 +883,7 @@ func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []net
rule = append(rule, destExp...)
if params.Proto != firewall.ProtocolALL {
rule = append(rule, "-p", strings.ToLower(protoForFamily(params.Proto, r.v6)))
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
rule = append(rule, applyPort("--sport", params.SPort)...)
rule = append(rule, applyPort("--dport", params.DPort)...)
}
@@ -927,12 +900,11 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []
}
if network.IsSet() {
name := r.ipsetName(network.Set.HashedName())
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
}
return []string{"-m", "set", matchSet, name, direction}, nil
return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil
}
if network.IsPrefix() {
return []string{flag, network.Prefix.String()}, nil
@@ -943,23 +915,27 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []
}
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
name := r.ipsetName(set.HashedName())
var merr *multierror.Error
for _, prefix := range prefixes {
if err := r.addPrefixToIPSet(name, prefix); err != nil {
// TODO: Implement IPv6 support
if prefix.Addr().Is6() {
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue
}
if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
}
}
if merr == nil {
log.Debugf("updated set %s with prefixes %v", name, prefixes)
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
}
return nberrors.FormatErrorOrNil(merr)
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
@@ -967,12 +943,12 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
dnatRule := []string{
"-i", r.wgIface.Name(),
"-p", strings.ToLower(protoForFamily(protocol, r.v6)),
"--dport", strconv.Itoa(int(originalPort)),
"-p", strings.ToLower(string(protocol)),
"--dport", strconv.Itoa(int(sourcePort)),
"-d", localAddr.String(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(translatedPort)),
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
}
ruleInfo := ruleInfo{
@@ -991,8 +967,8 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
@@ -1037,8 +1013,8 @@ func (r *router) ensureNATOutputChain() error {
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
@@ -1049,11 +1025,11 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
}
dnatRule := []string{
"-p", strings.ToLower(protoForFamily(protocol, localAddr.Is6())),
"--dport", strconv.Itoa(int(originalPort)),
"-p", strings.ToLower(string(protocol)),
"--dport", strconv.Itoa(int(sourcePort)),
"-d", localAddr.String(),
"-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(translatedPort)),
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
}
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
@@ -1066,8 +1042,8 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
@@ -1100,22 +1076,10 @@ func applyPort(flag string, port *firewall.Port) []string {
return []string{flag, strconv.Itoa(int(port.Values[0]))}
}
// ipsetName returns the ipset name, suffixed with "-v6" for the v6 router
// to avoid collisions since ipsets are global in the kernel.
func (r *router) ipsetName(name string) string {
if r.v6 {
return name + "-v6"
}
return name
}
func (r *router) createIPSet(name string) error {
opts := ipset.CreateOptions{
Replace: true,
}
if r.v6 {
opts.Family = ipset.FamilyIPV6
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err)

View File

@@ -9,7 +9,6 @@ type Rule struct {
mangleSpecs []string
ip string
chain string
v6 bool
}
// GetRuleID returns the rule id

View File

@@ -4,8 +4,6 @@ import (
"fmt"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -34,12 +32,6 @@ type ShutdownState struct {
ACLEntries aclEntries `json:"acl_entries,omitempty"`
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
// IPv6 counterparts
RouteRules6 routeRules `json:"route_rules_v6,omitempty"`
RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"`
ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"`
ACLIPsetStore6 *ipsetStore `json:"acl_ipset_store_v6,omitempty"`
}
func (s *ShutdownState) Name() string {
@@ -70,28 +62,6 @@ func (s *ShutdownState) Cleanup() error {
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
}
// Clean up v6 state even if the current run has no IPv6.
// The previous run may have left ip6tables rules behind.
if !ipt.hasIPv6() {
if err := ipt.createIPv6Components(s.InterfaceState, mtu); err != nil {
log.Warnf("failed to create v6 components for cleanup: %v", err)
}
}
if ipt.hasIPv6() {
if s.RouteRules6 != nil {
ipt.router6.rules = s.RouteRules6
}
if s.RouteIPsetCounter6 != nil {
ipt.router6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
}
if s.ACLEntries6 != nil {
ipt.aclMgr6.entries = s.ACLEntries6
}
if s.ACLIPsetStore6 != nil {
ipt.aclMgr6.ipsetStore = s.ACLIPsetStore6
}
}
if err := ipt.Close(nil); err != nil {
return fmt.Errorf("reset iptables manager: %w", err)
}

View File

@@ -1,7 +1,6 @@
package manager
import (
"errors"
"fmt"
"net"
"net/netip"
@@ -12,10 +11,6 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// ErrIPv6NotInitialized is returned when an IPv6 address is passed to a firewall
// method but the IPv6 firewall components were not initialized.
var ErrIPv6NotInitialized = errors.New("IPv6 firewall not initialized")
const (
ForwardingFormatPrefix = "netbird-fwd-"
ForwardingFormat = "netbird-fwd-%s-%t"
@@ -169,16 +164,18 @@ type Manager interface {
UpdateSet(hash Set, prefixes []netip.Prefix) error
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
AddInboundDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error
AddInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// RemoveInboundDNAT removes inbound DNAT rule
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
AddOutputDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
AddOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from interfering with WireGuard proxy communication.

View File

@@ -1,8 +1,6 @@
package manager
import (
"net/netip"
"github.com/netbirdio/netbird/route"
)
@@ -12,10 +10,6 @@ type RouterPair struct {
Destination Network
Masquerade bool
Inverse bool
// Dynamic indicates the route is domain-based. NAT rules for dynamic
// routes are duplicated to the v6 table so that resolved AAAA records
// are masqueraded correctly.
Dynamic bool
}
func GetInversePair(pair RouterPair) RouterPair {
@@ -26,17 +20,5 @@ func GetInversePair(pair RouterPair) RouterPair {
Destination: pair.Source,
Masquerade: pair.Masquerade,
Inverse: true,
Dynamic: pair.Dynamic,
}
}
// ToV6NatPair creates a v6 counterpart of a v4 NAT pair with `::/0` source
// and, for prefix destinations, `::/0` destination.
func ToV6NatPair(pair RouterPair) RouterPair {
v6 := pair
v6.Source = Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
if v6.Destination.IsPrefix() {
v6.Destination = Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
}
return v6
}

View File

@@ -33,12 +33,15 @@ const (
const flushError = "flush: %w"
var (
anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
)
type AclManager struct {
rConn *nftables.Conn
sConn *nftables.Conn
wgIface iFaceMapper
routingFwChainName string
af addrFamily
workTable *nftables.Table
chainInputRules *nftables.Chain
@@ -64,7 +67,6 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam
wgIface: wgIface,
workTable: table,
routingFwChainName: routingFwChainName,
af: familyForAddr(table.Family == nftables.TableFamilyIPv4),
ipsetStore: newIpsetStore(),
rules: make(map[string]*Rule),
@@ -143,7 +145,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
}
if _, ok := ips[r.ip.String()]; ok {
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: ipToBytes(r.ip, m.af)}})
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}})
if err != nil {
log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err)
}
@@ -252,11 +254,11 @@ func (m *AclManager) addIOFiltering(
expressions = append(expressions, &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: m.af.protoOffset,
Offset: uint32(9),
Len: uint32(1),
})
protoData, err := m.af.protoNum(proto)
protoData, err := protoToInt(proto)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %v", err)
}
@@ -268,16 +270,19 @@ func (m *AclManager) addIOFiltering(
})
}
rawIP := ipToBytes(ip, m.af)
rawIP := ip.To4()
// check if rawIP contains zeroed IPv4 0.0.0.0 value
// in that case not add IP match expression into the rule definition
if slices.ContainsFunc(rawIP, func(v byte) bool { return v != 0 }) {
if !bytes.HasPrefix(anyIP, rawIP) {
// source address position
addrOffset := uint32(12)
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: m.af.srcAddrOffset,
Len: m.af.addrLen,
Offset: addrOffset,
Len: 4,
},
)
// add individual IP for match if no ipset defined
@@ -582,7 +587,7 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr
func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) {
ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName)
rawIP := ipToBytes(ip, m.af)
rawIP := ip.To4()
if err != nil {
if ipset, err = m.createSet(m.workTable, ipsetName); err != nil {
return nil, fmt.Errorf("get set name: %v", err)
@@ -614,7 +619,7 @@ func (m *AclManager) createSet(table *nftables.Table, name string) (*nftables.Se
Name: name,
Table: table,
Dynamic: true,
KeyType: m.af.setKeyType,
KeyType: nftables.TypeIPAddr,
}
if err := m.rConn.AddSet(ipset, nil); err != nil {
@@ -702,12 +707,15 @@ func ifname(n string) []byte {
return b
}
// ipToBytes converts net.IP to the correct byte length for the address family.
func ipToBytes(ip net.IP, af addrFamily) []byte {
if af.addrLen == 4 {
return ip.To4()
func protoToInt(protocol firewall.Protocol) (uint8, error) {
switch protocol {
case firewall.ProtocolTCP:
return unix.IPPROTO_TCP, nil
case firewall.ProtocolUDP:
return unix.IPPROTO_UDP, nil
case firewall.ProtocolICMP:
return unix.IPPROTO_ICMP, nil
}
return ip.To16()
}
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
}

View File

@@ -1,81 +0,0 @@
package nftables
import (
"fmt"
"net"
"github.com/google/nftables"
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
var (
// afIPv4 defines IPv4 header layout and nftables types.
afIPv4 = addrFamily{
protoOffset: 9,
srcAddrOffset: 12,
dstAddrOffset: 16,
addrLen: net.IPv4len,
totalBits: 8 * net.IPv4len,
setKeyType: nftables.TypeIPAddr,
tableFamily: nftables.TableFamilyIPv4,
icmpProto: unix.IPPROTO_ICMP,
}
// afIPv6 defines IPv6 header layout and nftables types.
afIPv6 = addrFamily{
protoOffset: 6,
srcAddrOffset: 8,
dstAddrOffset: 24,
addrLen: net.IPv6len,
totalBits: 8 * net.IPv6len,
setKeyType: nftables.TypeIP6Addr,
tableFamily: nftables.TableFamilyIPv6,
icmpProto: unix.IPPROTO_ICMPV6,
}
)
// addrFamily holds protocol-specific constants for nftables expression building.
type addrFamily struct {
// protoOffset is the IP header offset for the protocol/next-header field (9 for v4, 6 for v6)
protoOffset uint32
// srcAddrOffset is the IP header offset for the source address (12 for v4, 8 for v6)
srcAddrOffset uint32
// dstAddrOffset is the IP header offset for the destination address (16 for v4, 24 for v6)
dstAddrOffset uint32
// addrLen is the byte length of addresses (4 for v4, 16 for v6)
addrLen uint32
// totalBits is the address size in bits (32 for v4, 128 for v6)
totalBits int
// setKeyType is the nftables set data type for addresses
setKeyType nftables.SetDatatype
// tableFamily is the nftables table family
tableFamily nftables.TableFamily
// icmpProto is the ICMP protocol number for this family (1 for v4, 58 for v6)
icmpProto uint8
}
// familyForAddr returns the address family for the given IP.
func familyForAddr(is4 bool) addrFamily {
if is4 {
return afIPv4
}
return afIPv6
}
// protoNum converts a firewall protocol to the IP protocol number,
// using the correct ICMP variant for the address family.
func (af addrFamily) protoNum(protocol firewall.Protocol) (uint8, error) {
switch protocol {
case firewall.ProtocolTCP:
return unix.IPPROTO_TCP, nil
case firewall.ProtocolUDP:
return unix.IPPROTO_UDP, nil
case firewall.ProtocolICMP:
return af.icmpProto, nil
case firewall.ProtocolALL:
return 0, nil
default:
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
}
}

View File

@@ -1,76 +0,0 @@
//go:build linux
package nftables
import (
"os"
"sync/atomic"
"testing"
"time"
"github.com/google/nftables"
"github.com/stretchr/testify/require"
)
// TestExternalChainMonitorRootIntegration verifies that adding a new chain
// in an external (non-netbird) filter table triggers the reconciler.
// Requires CAP_NET_ADMIN; skip otherwise.
func TestExternalChainMonitorRootIntegration(t *testing.T) {
if os.Geteuid() != 0 {
t.Skip("root required")
}
calls := make(chan struct{}, 8)
var count atomic.Int32
rec := &countingReconciler{calls: calls, count: &count}
m := newExternalChainMonitor(rec)
m.start()
t.Cleanup(m.stop)
// Give the netlink subscription a moment to register.
time.Sleep(200 * time.Millisecond)
conn := &nftables.Conn{}
table := conn.AddTable(&nftables.Table{
Name: "nbmon_integration_test",
Family: nftables.TableFamilyINet,
})
t.Cleanup(func() {
cleanup := &nftables.Conn{}
cleanup.DelTable(table)
_ = cleanup.Flush()
})
chain := conn.AddChain(&nftables.Chain{
Name: "filter_INPUT",
Table: table,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
Type: nftables.ChainTypeFilter,
})
_ = chain
require.NoError(t, conn.Flush(), "create external test chain")
select {
case <-calls:
// success
case <-time.After(3 * time.Second):
t.Fatalf("reconcile was not invoked after creating an external chain")
}
require.GreaterOrEqual(t, count.Load(), int32(1))
}
type countingReconciler struct {
calls chan struct{}
count *atomic.Int32
}
func (c *countingReconciler) reconcileExternalChains() error {
c.count.Add(1)
select {
case c.calls <- struct{}{}:
default:
}
return nil
}

View File

@@ -1,199 +0,0 @@
package nftables
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/google/nftables"
log "github.com/sirupsen/logrus"
)
const (
externalMonitorReconcileDelay = 500 * time.Millisecond
externalMonitorInitInterval = 5 * time.Second
externalMonitorMaxInterval = 5 * time.Minute
externalMonitorRandomization = 0.5
)
// externalChainReconciler re-applies passthrough accept rules to external
// nftables chains. Implementations must be safe to call from the monitor
// goroutine; the Manager locks its mutex internally.
type externalChainReconciler interface {
reconcileExternalChains() error
}
// externalChainMonitor watches nftables netlink events and triggers a
// reconcile when a new table or chain appears (e.g. after
// `firewall-cmd --reload`). Netlink errors trigger exponential-backoff
// reconnect.
type externalChainMonitor struct {
reconciler externalChainReconciler
mu sync.Mutex
cancel context.CancelFunc
done chan struct{}
}
func newExternalChainMonitor(r externalChainReconciler) *externalChainMonitor {
return &externalChainMonitor{reconciler: r}
}
func (m *externalChainMonitor) start() {
m.mu.Lock()
defer m.mu.Unlock()
if m.cancel != nil {
return
}
ctx, cancel := context.WithCancel(context.Background())
m.cancel = cancel
m.done = make(chan struct{})
go m.run(ctx)
}
func (m *externalChainMonitor) stop() {
m.mu.Lock()
cancel := m.cancel
done := m.done
m.cancel = nil
m.done = nil
m.mu.Unlock()
if cancel == nil {
return
}
cancel()
<-done
}
func (m *externalChainMonitor) run(ctx context.Context) {
defer close(m.done)
bo := &backoff.ExponentialBackOff{
InitialInterval: externalMonitorInitInterval,
RandomizationFactor: externalMonitorRandomization,
Multiplier: backoff.DefaultMultiplier,
MaxInterval: externalMonitorMaxInterval,
MaxElapsedTime: 0,
Clock: backoff.SystemClock,
}
bo.Reset()
for ctx.Err() == nil {
err := m.watch(ctx)
if ctx.Err() != nil {
return
}
delay := bo.NextBackOff()
log.Warnf("external chain monitor: %v, reconnecting in %s", err, delay)
select {
case <-ctx.Done():
return
case <-time.After(delay):
}
}
}
func (m *externalChainMonitor) watch(ctx context.Context) error {
events, closeMon, err := m.subscribe()
if err != nil {
return err
}
defer closeMon()
debounce := time.NewTimer(time.Hour)
if !debounce.Stop() {
<-debounce.C
}
defer debounce.Stop()
pending := false
for {
select {
case <-ctx.Done():
return nil
case <-debounce.C:
pending = false
m.reconcile()
case ev, ok := <-events:
if !ok {
return errors.New("monitor channel closed")
}
if ev.Error != nil {
return fmt.Errorf("monitor event: %w", ev.Error)
}
if !isRelevantMonitorEvent(ev) {
continue
}
resetDebounce(debounce, pending)
pending = true
}
}
}
func (m *externalChainMonitor) subscribe() (chan *nftables.MonitorEvent, func(), error) {
conn := &nftables.Conn{}
mon := nftables.NewMonitor(
nftables.WithMonitorAction(nftables.MonitorActionNew),
nftables.WithMonitorObject(nftables.MonitorObjectChains|nftables.MonitorObjectTables),
)
events, err := conn.AddMonitor(mon)
if err != nil {
return nil, nil, fmt.Errorf("add netlink monitor: %w", err)
}
return events, func() { _ = mon.Close() }, nil
}
// resetDebounce reschedules a pending debounce timer without leaking a stale
// fire on its channel. pending must reflect whether the timer is armed.
func resetDebounce(t *time.Timer, pending bool) {
if pending && !t.Stop() {
select {
case <-t.C:
default:
}
}
t.Reset(externalMonitorReconcileDelay)
}
func (m *externalChainMonitor) reconcile() {
if err := m.reconciler.reconcileExternalChains(); err != nil {
log.Warnf("reconcile external chain rules: %v", err)
}
}
// isRelevantMonitorEvent returns true for table/chain creation events on
// families we care about. The reconciler filters to actual external filter
// chains.
func isRelevantMonitorEvent(ev *nftables.MonitorEvent) bool {
switch ev.Type {
case nftables.MonitorEventTypeNewChain:
chain, ok := ev.Data.(*nftables.Chain)
if !ok || chain == nil || chain.Table == nil {
return false
}
return isMonitoredFamily(chain.Table.Family)
case nftables.MonitorEventTypeNewTable:
table, ok := ev.Data.(*nftables.Table)
if !ok || table == nil {
return false
}
return isMonitoredFamily(table.Family)
}
return false
}
func isMonitoredFamily(family nftables.TableFamily) bool {
switch family {
case nftables.TableFamilyIPv4, nftables.TableFamilyIPv6, nftables.TableFamilyINet:
return true
}
return false
}

View File

@@ -1,137 +0,0 @@
package nftables
import (
"testing"
"github.com/google/nftables"
"github.com/stretchr/testify/assert"
)
func TestIsMonitoredFamily(t *testing.T) {
tests := []struct {
family nftables.TableFamily
want bool
}{
{nftables.TableFamilyIPv4, true},
{nftables.TableFamilyIPv6, true},
{nftables.TableFamilyINet, true},
{nftables.TableFamilyARP, false},
{nftables.TableFamilyBridge, false},
{nftables.TableFamilyNetdev, false},
{nftables.TableFamilyUnspecified, false},
}
for _, tc := range tests {
assert.Equal(t, tc.want, isMonitoredFamily(tc.family), "family=%d", tc.family)
}
}
func TestIsRelevantMonitorEvent(t *testing.T) {
inetTable := &nftables.Table{Name: "firewalld", Family: nftables.TableFamilyINet}
ipTable := &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4}
arpTable := &nftables.Table{Name: "arp", Family: nftables.TableFamilyARP}
tests := []struct {
name string
ev *nftables.MonitorEvent
want bool
}{
{
name: "new chain in inet firewalld",
ev: &nftables.MonitorEvent{
Type: nftables.MonitorEventTypeNewChain,
Data: &nftables.Chain{Name: "filter_INPUT", Table: inetTable},
},
want: true,
},
{
name: "new chain in ip filter",
ev: &nftables.MonitorEvent{
Type: nftables.MonitorEventTypeNewChain,
Data: &nftables.Chain{Name: "INPUT", Table: ipTable},
},
want: true,
},
{
name: "new chain in unwatched arp family",
ev: &nftables.MonitorEvent{
Type: nftables.MonitorEventTypeNewChain,
Data: &nftables.Chain{Name: "x", Table: arpTable},
},
want: false,
},
{
name: "new table inet",
ev: &nftables.MonitorEvent{
Type: nftables.MonitorEventTypeNewTable,
Data: inetTable,
},
want: true,
},
{
name: "del chain (we only act on new)",
ev: &nftables.MonitorEvent{
Type: nftables.MonitorEventTypeDelChain,
Data: &nftables.Chain{Name: "filter_INPUT", Table: inetTable},
},
want: false,
},
{
name: "chain with nil table",
ev: &nftables.MonitorEvent{
Type: nftables.MonitorEventTypeNewChain,
Data: &nftables.Chain{Name: "x"},
},
want: false,
},
{
name: "nil data",
ev: &nftables.MonitorEvent{
Type: nftables.MonitorEventTypeNewChain,
Data: (*nftables.Chain)(nil),
},
want: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, isRelevantMonitorEvent(tc.ev))
})
}
}
// fakeReconciler records reconcile invocations for debounce tests.
type fakeReconciler struct {
calls chan struct{}
}
func (f *fakeReconciler) reconcileExternalChains() error {
f.calls <- struct{}{}
return nil
}
func TestExternalChainMonitorStopWithoutStart(t *testing.T) {
m := newExternalChainMonitor(&fakeReconciler{calls: make(chan struct{}, 1)})
// Must not panic or block.
m.stop()
}
func TestExternalChainMonitorDoubleStart(t *testing.T) {
// start() twice should be a no-op; stop() cleans up once.
// We avoid exercising the netlink watch loop here because it needs root.
m := newExternalChainMonitor(&fakeReconciler{calls: make(chan struct{}, 1)})
// Replace run with a stub that just waits for cancel, so start() stays
// deterministic without opening a netlink socket.
origDone := make(chan struct{})
m.done = origDone
m.cancel = func() { close(origDone) }
// Second start should be a no-op (cancel already set).
m.start()
assert.NotNil(t, m.cancel)
m.stop()
assert.Nil(t, m.cancel)
assert.Nil(t, m.done)
}

View File

@@ -11,11 +11,9 @@ import (
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
@@ -51,17 +49,10 @@ type Manager struct {
rConn *nftables.Conn
wgIface iFaceMapper
router *router
aclManager *AclManager
// IPv6 counterparts, nil when no v6 overlay
router6 *router
aclManager6 *AclManager
router *router
aclManager *AclManager
notrackOutputChain *nftables.Chain
notrackPreroutingChain *nftables.Chain
extMonitor *externalChainMonitor
}
// Create nftables firewall manager
@@ -71,8 +62,7 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
wgIface: wgIface,
}
tableName := getTableName()
workTable := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}
workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}
var err error
m.router, err = newRouter(workTable, wgIface, mtu)
@@ -85,137 +75,35 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
return nil, fmt.Errorf("create acl manager: %w", err)
}
if wgIface.Address().HasIPv6() {
if err := m.createIPv6Components(tableName, wgIface, mtu); err != nil {
return nil, fmt.Errorf("create IPv6 firewall: %w", err)
}
}
m.extMonitor = newExternalChainMonitor(m)
return m, nil
}
func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mtu uint16) error {
workTable6 := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6}
var err error
m.router6, err = newRouter(workTable6, wgIface, mtu)
if err != nil {
return fmt.Errorf("create v6 router: %w", err)
}
// Share the same IP forwarding state with the v4 router, since
// EnableIPForwarding controls both v4 and v6 sysctls.
m.router6.ipFwdState = m.router.ipFwdState
m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw)
if err != nil {
return fmt.Errorf("create v6 acl manager: %w", err)
}
return nil
}
// hasIPv6 reports whether the manager has IPv6 components initialized.
func (m *Manager) hasIPv6() bool {
return m.router6 != nil
}
func (m *Manager) initIPv6() error {
workTable6, err := m.createWorkTableFamily(nftables.TableFamilyIPv6)
if err != nil {
return fmt.Errorf("create v6 work table: %w", err)
}
if err := m.router6.init(workTable6); err != nil {
return fmt.Errorf("v6 router init: %w", err)
}
if err := m.aclManager6.init(workTable6); err != nil {
return fmt.Errorf("v6 acl manager init: %w", err)
}
return nil
}
// Init nftables firewall manager
func (m *Manager) Init(stateManager *statemanager.Manager) error {
if err := m.initFirewall(); err != nil {
return err
}
m.persistState(stateManager)
// Start after initFirewall has installed the baseline external-chain
// accept rules. start() is idempotent across Init/Close/Init cycles.
m.extMonitor.start()
return nil
}
// reconcileExternalChains re-applies passthrough accept rules to external
// filter chains for both IPv4 and IPv6 routers. Called by the monitor when
// tables or chains appear (e.g. after firewalld reloads).
func (m *Manager) reconcileExternalChains() error {
m.mutex.Lock()
defer m.mutex.Unlock()
var merr *multierror.Error
if m.router != nil {
if err := m.router.acceptExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("v4: %w", err))
}
}
if m.hasIPv6() {
if err := m.router6.acceptExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("v6: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (m *Manager) initFirewall() (err error) {
workTable, err := m.createWorkTable()
if err != nil {
return fmt.Errorf("create work table: %w", err)
}
defer func() {
if err != nil {
m.rollbackInit()
}
}()
if err := m.router.init(workTable); err != nil {
return fmt.Errorf("router init: %w", err)
}
if err := m.aclManager.init(workTable); err != nil {
// TODO: cleanup router
return fmt.Errorf("acl manager init: %w", err)
}
if m.hasIPv6() {
if err := m.initIPv6(); err != nil {
// Peer has a v6 address: v6 firewall MUST work or we risk fail-open.
return fmt.Errorf("init IPv6 firewall (required because peer has IPv6 address): %w", err)
}
}
if err := m.initNoTrackChains(workTable); err != nil {
log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err)
}
return nil
}
// persistState saves the current interface state for potential recreation on restart.
// Unlike iptables, which requires tracking individual rules, nftables maintains
// a known state (our netbird table plus a few static rules). This allows for easy
// cleanup using Close() without needing to store specific rules.
func (m *Manager) persistState(stateManager *statemanager.Manager) {
stateManager.RegisterState(&ShutdownState{})
// We only need to record minimal interface state for potential recreation.
// Unlike iptables, which requires tracking individual rules, nftables maintains
// a known state (our netbird table plus a few static rules). This allows for easy
// cleanup using Close() without needing to store specific rules.
if err := stateManager.UpdateState(&ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
@@ -226,29 +114,14 @@ func (m *Manager) persistState(stateManager *statemanager.Manager) {
log.Errorf("failed to update state: %v", err)
}
// persist early
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
}()
}
// rollbackInit performs best-effort cleanup of already-initialized state when Init fails partway through.
func (m *Manager) rollbackInit() {
if err := m.router.Reset(); err != nil {
log.Warnf("rollback router: %v", err)
}
if m.hasIPv6() {
if err := m.router6.Reset(); err != nil {
log.Warnf("rollback v6 router: %v", err)
}
}
if err := m.cleanupNetbirdTables(); err != nil {
log.Warnf("cleanup tables: %v", err)
}
if err := m.rConn.Flush(); err != nil {
log.Warnf("flush: %v", err)
}
return nil
}
// AddPeerFiltering rule to the firewall
@@ -267,14 +140,12 @@ func (m *Manager) AddPeerFiltering(
m.mutex.Lock()
defer m.mutex.Unlock()
if ip.To4() != nil {
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
rawIP := ip.To4()
if rawIP == nil {
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
}
if !m.hasIPv6() {
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
}
return m.aclManager6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
func (m *Manager) AddRouteFiltering(
@@ -288,11 +159,8 @@ func (m *Manager) AddRouteFiltering(
m.mutex.Lock()
defer m.mutex.Unlock()
if isIPv6RouteRule(sources, destination) {
if !m.hasIPv6() {
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
}
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
@@ -303,66 +171,15 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && isIPv6Rule(rule) {
return m.aclManager6.DeletePeerRule(rule)
}
return m.aclManager.DeletePeerRule(rule)
}
func isIPv6Rule(rule firewall.Rule) bool {
r, ok := rule.(*Rule)
return ok && r.nftRule != nil && r.nftRule.Table != nil && r.nftRule.Table.Family == nftables.TableFamilyIPv6
}
// isIPv6RouteRule determines whether a route rule belongs to the v6 table.
// For static routes, the destination prefix determines the family. For dynamic
// routes (DomainSet), the sources determine the family since management
// duplicates dynamic rules per family.
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
if destination.IsPrefix() {
return destination.Prefix.Addr().Is6()
}
return len(sources) > 0 && sources[0].Addr().Is6()
}
// DeleteRouteRule deletes a routing rule. Route rules live in exactly one
// router; the cached maps are normally authoritative, so the kernel is only
// consulted when neither map knows about the rule.
// DeleteRouteRule deletes a routing rule
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
id := rule.ID()
r, err := m.routerForRuleID(id, (*router).hasRule)
if err != nil {
return err
}
return r.DeleteRouteRule(rule)
}
// routerForRuleID picks the router holding the rule with the given id, using
// the supplied lookup. If the cached maps disagree (or both miss), it refreshes
// from the kernel once and re-checks before falling back to the v4 router.
func (m *Manager) routerForRuleID(id string, has func(*router, string) bool) (*router, error) {
if has(m.router, id) {
return m.router, nil
}
if m.hasIPv6() && has(m.router6, id) {
return m.router6, nil
}
if !m.hasIPv6() {
return m.router, nil
}
if err := m.router.refreshRulesMap(); err != nil {
return nil, fmt.Errorf("refresh v4 rules: %w", err)
}
if err := m.router6.refreshRulesMap(); err != nil {
return nil, fmt.Errorf("refresh v6 rules: %w", err)
}
if has(m.router6, id) && !has(m.router, id) {
return m.router6, nil
}
return m.router, nil
return m.router.DeleteRouteRule(rule)
}
func (m *Manager) IsServerRouteSupported() bool {
@@ -377,70 +194,19 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
if !m.hasIPv6() {
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddNatRule(pair)
}
if err := m.router.AddNatRule(pair); err != nil {
return err
}
// Dynamic routes need NAT in both tables since resolved IPs can be
// either v4 or v6. This covers both DomainSet (modern) and the legacy
// wildcard 0.0.0.0/0 destination where the client resolves DNS.
// On v6 failure we keep the v4 NAT rule rather than rolling back: half
// connectivity is better than none, and RemoveNatRule is content-keyed
// so the eventual cleanup still works.
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.router6.AddNatRule(v6Pair); err != nil {
return fmt.Errorf("add v6 NAT rule: %w", err)
}
}
return nil
return m.router.AddNatRule(pair)
}
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
if !m.hasIPv6() {
return nil
}
return m.router6.RemoveNatRule(pair)
}
var merr *multierror.Error
if err := m.router.RemoveNatRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
}
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
return m.router.RemoveNatRule(pair)
}
// AllowNetbird allows netbird interface traffic.
// This is called when USPFilter wraps the native firewall, adding blanket accept
// rules so that packet filtering is handled in userspace instead of by netfilter.
//
// TODO: In USP mode this only adds ACCEPT to the netbird table's own chains,
// which doesn't override DROP rules in external tables (e.g. firewalld).
// Should add passthrough rules to external chains (like the native mode router's
// addExternalChainsRules does) for both the netbird table family and inet tables.
// The netbird table itself is fine (routing chains already exist there), but
// non-netbird tables with INPUT/FORWARD hooks can still DROP our WG traffic.
func (m *Manager) AllowNetbird() error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -448,11 +214,6 @@ func (m *Manager) AllowNetbird() error {
if err := m.aclManager.createDefaultAllowRules(); err != nil {
return fmt.Errorf("create default allow rules: %w", err)
}
if m.hasIPv6() {
if err := m.aclManager6.createDefaultAllowRules(); err != nil {
return fmt.Errorf("create v6 default allow rules: %w", err)
}
}
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf("flush allow input netbird rules: %w", err)
}
@@ -466,47 +227,31 @@ func (m *Manager) AllowNetbird() error {
// SetLegacyManagement sets the route manager to use legacy management
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
return err
}
if m.hasIPv6() {
return firewall.SetLegacyManagement(m.router6, isLegacy)
}
return nil
return firewall.SetLegacyManagement(m.router, isLegacy)
}
// Close closes the firewall manager
func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.extMonitor.stop()
m.mutex.Lock()
defer m.mutex.Unlock()
var merr *multierror.Error
if err := m.router.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset router: %v", err))
}
if m.hasIPv6() {
if err := m.router6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %v", err))
}
return fmt.Errorf("reset router: %v", err)
}
if err := m.cleanupNetbirdTables(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("cleanup netbird tables: %v", err))
return fmt.Errorf("cleanup netbird tables: %v", err)
}
if err := m.rConn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
return fmt.Errorf(flushError, err)
}
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete state: %v", err))
return fmt.Errorf("delete state: %v", err)
}
return nberrors.FormatErrorOrNil(merr)
return nil
}
func (m *Manager) cleanupNetbirdTables() error {
@@ -555,12 +300,6 @@ func (m *Manager) Flush() error {
return err
}
if m.hasIPv6() {
if err := m.aclManager6.Flush(); err != nil {
return fmt.Errorf("flush v6 acl: %w", err)
}
}
if err := m.refreshNoTrackChains(); err != nil {
log.Errorf("failed to refresh notrack chains: %v", err)
}
@@ -573,12 +312,6 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
m.mutex.Lock()
defer m.mutex.Unlock()
if rule.TranslatedAddress.Is6() {
if !m.hasIPv6() {
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddDNATRule(rule)
}
return m.router.AddDNATRule(rule)
}
@@ -587,11 +320,7 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
r, err := m.routerForRuleID(rule.ID(), (*router).hasDNATRule)
if err != nil {
return err
}
return r.DeleteDNATRule(rule)
return m.router.DeleteDNATRule(rule)
}
// UpdateSet updates the set with the given prefixes
@@ -599,82 +328,39 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
m.mutex.Lock()
defer m.mutex.Unlock()
var v4Prefixes, v6Prefixes []netip.Prefix
for _, p := range prefixes {
if p.Addr().Is6() {
v6Prefixes = append(v6Prefixes, p)
} else {
v4Prefixes = append(v4Prefixes, p)
}
}
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
return err
}
if m.hasIPv6() && len(v6Prefixes) > 0 {
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
return fmt.Errorf("update v6 set: %w", err)
}
}
return nil
return m.router.UpdateSet(set, prefixes)
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if localAddr.Is6() {
if !m.hasIPv6() {
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if localAddr.Is6() {
if !m.hasIPv6() {
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if localAddr.Is6() {
if !m.hasIPv6() {
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if localAddr.Is6() {
if !m.hasIPv6() {
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
const (
@@ -848,11 +534,7 @@ func (m *Manager) refreshNoTrackChains() error {
}
func (m *Manager) createWorkTable() (*nftables.Table, error) {
return m.createWorkTableFamily(nftables.TableFamilyIPv4)
}
func (m *Manager) createWorkTableFamily(family nftables.TableFamily) (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(family)
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, fmt.Errorf("list of tables: %w", err)
}
@@ -864,7 +546,7 @@ func (m *Manager) createWorkTableFamily(family nftables.TableFamily) (*nftables.
}
}
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: family})
table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4})
err = m.rConn.Flush()
return table, err
}

View File

@@ -383,138 +383,10 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
err = manager.AddNatRule(pair)
require.NoError(t, err, "failed to add NAT rule")
dnatRule, err := manager.AddDNATRule(fw.ForwardRule{
Protocol: fw.ProtocolTCP,
DestinationPort: fw.Port{Values: []uint16{8080}},
TranslatedAddress: netip.MustParseAddr("100.96.0.2"),
TranslatedPort: fw.Port{Values: []uint16{80}},
})
require.NoError(t, err, "failed to add DNAT rule")
t.Cleanup(func() {
require.NoError(t, manager.DeleteDNATRule(dnatRule), "failed to delete DNAT rule")
})
stdout, stderr = runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
}
func TestNftablesManagerIPv6CompatibilityWithIp6tables(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
for _, bin := range []string{"ip6tables", "ip6tables-save", "iptables-save"} {
if _, err := exec.LookPath(bin); err != nil {
t.Skipf("%s not available on this system: %v", bin, err)
}
}
// Seed ip6 tables in the nft backend. Docker may not create them.
seedIp6tables(t)
ifaceMockV6 := &iFaceMock{
NameFunc: func() string { return "wt-test" },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.96.0.1"),
Network: netip.MustParsePrefix("100.96.0.0/16"),
IPv6: netip.MustParseAddr("fd00::1"),
IPv6Net: netip.MustParsePrefix("fd00::/64"),
}
},
}
manager, err := Create(ifaceMockV6, iface.DefaultMTU)
require.NoError(t, err, "create manager")
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
require.NoError(t, manager.Close(nil), "close manager")
stdout, stderr := runIp6tablesSave(t)
verifyIp6tablesOutput(t, stdout, stderr)
})
ip := netip.MustParseAddr("fd00::2")
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err, "add v6 peer filtering rule")
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("fd00:1::/64")},
fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err, "add v6 route filtering rule")
err = manager.AddNatRule(fw.RouterPair{
Source: fw.Network{Prefix: netip.MustParsePrefix("fd00::/64")},
Destination: fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
Masquerade: true,
})
require.NoError(t, err, "add v6 NAT rule")
dnatRule, err := manager.AddDNATRule(fw.ForwardRule{
Protocol: fw.ProtocolTCP,
DestinationPort: fw.Port{Values: []uint16{8080}},
TranslatedAddress: netip.MustParseAddr("fd00::2"),
TranslatedPort: fw.Port{Values: []uint16{80}},
})
require.NoError(t, err, "add v6 DNAT rule")
t.Cleanup(func() {
require.NoError(t, manager.DeleteDNATRule(dnatRule), "delete v6 DNAT rule")
})
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
stdout, stderr = runIp6tablesSave(t)
verifyIp6tablesOutput(t, stdout, stderr)
}
func seedIp6tables(t *testing.T) {
t.Helper()
for _, tc := range []struct{ table, chain string }{
{"filter", "FORWARD"},
{"nat", "POSTROUTING"},
{"mangle", "FORWARD"},
} {
add := exec.Command("ip6tables", "-t", tc.table, "-A", tc.chain, "-j", "ACCEPT")
require.NoError(t, add.Run(), "seed ip6tables -t %s", tc.table)
del := exec.Command("ip6tables", "-t", tc.table, "-D", tc.chain, "-j", "ACCEPT")
require.NoError(t, del.Run(), "unseed ip6tables -t %s", tc.table)
}
}
func runIp6tablesSave(t *testing.T) (string, string) {
t.Helper()
var stdout, stderr bytes.Buffer
cmd := exec.Command("ip6tables-save")
cmd.Stdout = &stdout
cmd.Stderr = &stderr
require.NoError(t, cmd.Run(), "ip6tables-save failed")
return stdout.String(), stderr.String()
}
func verifyIp6tablesOutput(t *testing.T, stdout, stderr string) {
t.Helper()
for _, msg := range []string{
"Table `nat' is incompatible",
"Table `mangle' is incompatible",
"Table `filter' is incompatible",
} {
require.NotContains(t, stdout, msg,
"ip6tables-save stdout reports incompatibility: %s", stdout)
require.NotContains(t, stderr, msg,
"ip6tables-save stderr reports incompatibility: %s", stderr)
}
}
func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")

View File

@@ -50,10 +50,8 @@ const (
dnatSuffix = "_dnat"
snatSuffix = "_snat"
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
ipv4TCPHeaderSize = 40
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
ipv6TCPHeaderSize = 60
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
// maxPrefixesSet 1638 prefixes start to fail, taking some margin
maxPrefixesSet = 1500
@@ -78,7 +76,6 @@ type router struct {
rules map[string]*nftables.Rule
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
af addrFamily
wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState
legacyManagement bool
@@ -91,7 +88,6 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
workTable: workTable,
chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4),
wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(),
mtu: mtu,
@@ -154,7 +150,7 @@ func (r *router) Reset() error {
func (r *router) removeNatPreroutingRules() error {
table := &nftables.Table{
Name: tableNat,
Family: r.af.tableFamily,
Family: nftables.TableFamilyIPv4,
}
chain := &nftables.Chain{
Name: chainNameNatPrerouting,
@@ -187,7 +183,7 @@ func (r *router) removeNatPreroutingRules() error {
}
func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(r.af.tableFamily)
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, fmt.Errorf("list tables: %w", err)
}
@@ -423,7 +419,7 @@ func (r *router) AddRouteFiltering(
// Handle protocol
if proto != firewall.ProtocolALL {
protoNum, err := r.af.protoNum(proto)
protoNum, err := protoToInt(proto)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
@@ -483,24 +479,7 @@ func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bo
return nil, fmt.Errorf("create or get ipset: %w", err)
}
return r.getIpSetExprs(ref, isSource)
}
func (r *router) iptablesProto() iptables.Protocol {
if r.af.tableFamily == nftables.TableFamilyIPv6 {
return iptables.ProtocolIPv6
}
return iptables.ProtocolIPv4
}
func (r *router) hasRule(id string) bool {
_, ok := r.rules[id]
return ok
}
func (r *router) hasDNATRule(id string) bool {
_, ok := r.rules[id+dnatSuffix]
return ok
return getIpSetExprs(ref, isSource)
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
@@ -549,10 +528,10 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
Table: r.workTable,
// required for prefixes
Interval: true,
KeyType: r.af.setKeyType,
KeyType: nftables.TypeIPAddr,
}
elements := r.convertPrefixesToSet(prefixes)
elements := convertPrefixesToSet(prefixes)
nElements := len(elements)
maxElements := maxPrefixesSet * 2
@@ -585,17 +564,23 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
return nfset, nil
}
func (r *router) convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
var elements []nftables.SetElement
for _, prefix := range prefixes {
// TODO: Implement IPv6 support
if prefix.Addr().Is6() {
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue
}
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
firstIP := prefix.Addr()
lastIP := calculateLastIP(prefix).Next()
elements = append(elements,
// the nft tool also adds a zero-address IntervalEnd element, see https://github.com/google/nftables/issues/247
// nftables.SetElement{Key: make([]byte, r.af.addrLen), IntervalEnd: true},
// the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247
// nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true},
nftables.SetElement{Key: firstIP.AsSlice()},
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
)
@@ -605,20 +590,10 @@ func (r *router) convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetEle
// calculateLastIP determines the last IP in a given prefix.
func calculateLastIP(prefix netip.Prefix) netip.Addr {
masked := prefix.Masked()
if masked.Addr().Is4() {
hostMask := ^uint32(0) >> masked.Bits()
lastIP := uint32FromNetipAddr(masked.Addr()) | hostMask
return netip.AddrFrom4(uint32ToBytes(lastIP))
}
hostMask := ^uint32(0) >> prefix.Masked().Bits()
lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask
// IPv6: set host bits to all 1s
b := masked.Addr().As16()
bits := masked.Bits()
for i := bits; i < 128; i++ {
b[i/8] |= 1 << (7 - i%8)
}
return netip.AddrFrom16(b)
return netip.AddrFrom4(uint32ToBytes(lastIP))
}
// Utility function to convert netip.Addr to uint32.
@@ -870,16 +845,9 @@ func (r *router) addPostroutingRules() {
}
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
// TODO: Add IPv6 support
func (r *router) addMSSClampingRules() error {
overhead := uint16(ipv4TCPHeaderSize)
if r.af.tableFamily == nftables.TableFamilyIPv6 {
overhead = ipv6TCPHeaderSize
}
if r.mtu <= overhead {
log.Debugf("MTU %d too small for MSS clamping (overhead %d), skipping", r.mtu, overhead)
return nil
}
mss := r.mtu - overhead
mss := r.mtu - ipTCPHeaderMinSize
exprsOut := []expr.Any{
&expr.Meta{
@@ -1086,22 +1054,17 @@ func (r *router) acceptFilterTableRules() error {
log.Debugf("Used %s to add accept forward and input rules", fw)
}()
// Try iptables first and fallback to nftables if iptables is not available.
// Use the correct protocol (iptables vs ip6tables) for the address family.
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
// Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New()
if err != nil {
// iptables is not available but the filter table exists
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
fw = "nftables"
return r.acceptFilterRulesNftables(r.filterTable)
}
if err := r.acceptFilterRulesIptables(ipt); err != nil {
log.Warnf("iptables failed (table may be incompatible), falling back to nftables: %v", err)
fw = "nftables"
return r.acceptFilterRulesNftables(r.filterTable)
}
return nil
return r.acceptFilterRulesIptables(ipt)
}
func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
@@ -1172,122 +1135,83 @@ func (r *router) acceptExternalChainsRules() error {
}
intf := ifname(r.wgIface.Name())
for _, chain := range chains {
r.applyExternalChainAccept(chain, intf)
if chain.Hooknum == nil {
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
continue
}
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
switch *chain.Hooknum {
case *nftables.ChainHookForward:
r.insertForwardAcceptRules(chain, intf)
case *nftables.ChainHookInput:
r.insertInputAcceptRule(chain, intf)
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush external chain rules: %w", err)
}
return nil
}
func (r *router) applyExternalChainAccept(chain *nftables.Chain, intf []byte) {
if chain.Hooknum == nil {
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
return
}
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
switch *chain.Hooknum {
case *nftables.ChainHookForward:
r.insertForwardAcceptRules(chain, intf)
case *nftables.ChainHookInput:
r.insertInputAcceptRule(chain, intf)
}
}
func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
existing, err := r.existingNetbirdRulesInChain(chain)
if err != nil {
log.Warnf("skip forward accept rules in %s/%s: %v", chain.Table.Name, chain.Name, err)
return
}
r.insertForwardIifRule(chain, intf, existing)
r.insertForwardOifEstablishedRule(chain, intf, existing)
}
func (r *router) insertForwardIifRule(chain *nftables.Chain, intf []byte, existing map[string]bool) {
if existing[userDataAcceptForwardRuleIif] {
return
}
r.conn.InsertRule(&nftables.Rule{
iifRule := &nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(userDataAcceptForwardRuleIif),
})
}
}
r.conn.InsertRule(iifRule)
func (r *router) insertForwardOifEstablishedRule(chain *nftables.Chain, intf []byte, existing map[string]bool) {
if existing[userDataAcceptForwardRuleOif] {
return
}
exprs := []expr.Any{
oifExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
}
r.conn.InsertRule(&nftables.Rule{
oifRule := &nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: append(exprs, getEstablishedExprs(2)...),
Exprs: append(oifExprs, getEstablishedExprs(2)...),
UserData: []byte(userDataAcceptForwardRuleOif),
})
}
r.conn.InsertRule(oifRule)
}
func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
existing, err := r.existingNetbirdRulesInChain(chain)
if err != nil {
log.Warnf("skip input accept rule in %s/%s: %v", chain.Table.Name, chain.Name, err)
return
}
if existing[userDataAcceptInputRule] {
return
}
r.conn.InsertRule(&nftables.Rule{
inputRule := &nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(userDataAcceptInputRule),
})
}
// existingNetbirdRulesInChain returns the set of netbird-owned UserData tags present in a chain; callers must bail on error since InsertRule is additive.
func (r *router) existingNetbirdRulesInChain(chain *nftables.Chain) (map[string]bool, error) {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return nil, fmt.Errorf("list rules: %w", err)
}
present := map[string]bool{}
for _, rule := range rules {
if !isNetbirdAcceptRuleTag(rule.UserData) {
continue
}
present[string(rule.UserData)] = true
}
return present, nil
}
func isNetbirdAcceptRuleTag(userData []byte) bool {
switch string(userData) {
case userDataAcceptForwardRuleIif,
userDataAcceptForwardRuleOif,
userDataAcceptInputRule:
return true
}
return false
r.conn.InsertRule(inputRule)
}
func (r *router) removeAcceptFilterRules() error {
@@ -1309,17 +1233,13 @@ func (r *router) removeFilterTableRules() error {
return nil
}
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
ipt, err := iptables.New()
if err != nil {
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
return r.removeAcceptRulesFromTable(r.filterTable)
}
if err := r.removeAcceptFilterRulesIptables(ipt); err != nil {
log.Debugf("iptables removal failed (table may be incompatible), falling back to nftables: %v", err)
return r.removeAcceptRulesFromTable(r.filterTable)
}
return nil
return r.removeAcceptFilterRulesIptables(ipt)
}
func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
@@ -1386,7 +1306,7 @@ func (r *router) removeExternalChainsRules() error {
func (r *router) findExternalChains() []*nftables.Chain {
var chains []*nftables.Chain
families := []nftables.TableFamily{r.af.tableFamily, nftables.TableFamilyINet}
families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet}
for _, family := range families {
allChains, err := r.conn.ListChainsOfTableFamily(family)
@@ -1417,8 +1337,8 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool {
return false
}
// Skip iptables/ip6tables-managed tables (adding nft-native rules breaks iptables-save compat)
if (chain.Table.Family == nftables.TableFamilyIPv4 || chain.Table.Family == nftables.TableFamilyIPv6) && isIptablesTable(chain.Table.Name) {
// Skip all iptables-managed tables in the ip family
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
return false
}
@@ -1559,7 +1479,7 @@ func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
return rule, nil
}
protoNum, err := r.af.protoNum(rule.Protocol)
protoNum, err := protoToInt(rule.Protocol)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
@@ -1622,7 +1542,7 @@ func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, rule
dnatExprs = append(dnatExprs,
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(r.af.tableFamily),
Family: uint32(nftables.TableFamilyIPv4),
RegAddrMin: 1,
RegProtoMin: regProtoMin,
RegProtoMax: regProtoMax,
@@ -1715,15 +1635,14 @@ func (r *router) addXTablesRedirect(dnatExprs []expr.Any, ruleKey string, rule f
},
)
natTable := &nftables.Table{
Name: tableNat,
Family: r.af.tableFamily,
}
dnatRule := &nftables.Rule{
Table: natTable,
Table: &nftables.Table{
Name: tableNat,
Family: nftables.TableFamilyIPv4,
},
Chain: &nftables.Chain{
Name: chainNameNatPrerouting,
Table: natTable,
Table: r.filterTable,
Type: nftables.ChainTypeNAT,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
@@ -1754,8 +1673,8 @@ func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: r.af.dstAddrOffset,
Len: r.af.addrLen,
Offset: 16,
Len: 4,
},
&expr.Cmp{
Op: expr.CmpOpEq,
@@ -1833,7 +1752,7 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return fmt.Errorf("get set %s: %w", set.HashedName(), err)
}
elements := r.convertPrefixesToSet(prefixes)
elements := convertPrefixesToSet(prefixes)
if err := r.conn.SetAddElements(nfset, elements); err != nil {
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
}
@@ -1848,14 +1767,14 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
protoNum, err := r.af.protoNum(protocol)
protoNum, err := protoToInt(protocol)
if err != nil {
return fmt.Errorf("convert protocol to number: %w", err)
}
@@ -1882,15 +1801,11 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 3,
Data: binaryutil.BigEndian.PutUint16(originalPort),
Data: binaryutil.BigEndian.PutUint16(sourcePort),
},
}
bits := 32
if localAddr.Is6() {
bits = 128
}
exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...)
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
exprs = append(exprs,
&expr.Immediate{
@@ -1899,11 +1814,11 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(translatedPort),
Data: binaryutil.BigEndian.PutUint16(targetPort),
},
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(r.af.tableFamily),
Family: uint32(nftables.TableFamilyIPv4),
RegAddrMin: 1,
RegProtoMin: 2,
RegProtoMax: 0,
@@ -1928,12 +1843,12 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
rule, exists := r.rules[ruleID]
if !exists {
@@ -1979,8 +1894,8 @@ func (r *router) ensureNATOutputChain() error {
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
@@ -1990,7 +1905,7 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
return err
}
protoNum, err := r.af.protoNum(protocol)
protoNum, err := protoToInt(protocol)
if err != nil {
return fmt.Errorf("convert protocol to number: %w", err)
}
@@ -2011,15 +1926,11 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: binaryutil.BigEndian.PutUint16(originalPort),
Data: binaryutil.BigEndian.PutUint16(sourcePort),
},
}
bits := 32
if localAddr.Is6() {
bits = 128
}
exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...)
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
exprs = append(exprs,
&expr.Immediate{
@@ -2028,11 +1939,11 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(translatedPort),
Data: binaryutil.BigEndian.PutUint16(targetPort),
},
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(r.af.tableFamily),
Family: uint32(nftables.TableFamilyIPv4),
RegAddrMin: 1,
RegProtoMin: 2,
},
@@ -2056,12 +1967,12 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
rule, exists := r.rules[ruleID]
if !exists {
@@ -2100,44 +2011,45 @@ func (r *router) applyNetwork(
}
if network.IsPrefix() {
return r.applyPrefix(network.Prefix, isSource), nil
return applyPrefix(network.Prefix, isSource), nil
}
return nil, nil
}
// applyPrefix generates nftables expressions for a CIDR prefix
func (r *router) applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
// dst offset by default
offset := r.af.dstAddrOffset
func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
// dst offset
offset := uint32(16)
if isSource {
// src offset
offset = r.af.srcAddrOffset
offset = 12
}
ones := prefix.Bits()
// unspecified address (/0) doesn't need extra expressions
// 0.0.0.0/0 doesn't need extra expressions
if ones == 0 {
return nil
}
mask := net.CIDRMask(ones, r.af.totalBits)
xor := make([]byte, r.af.addrLen)
mask := net.CIDRMask(ones, 32)
return []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offset,
Len: r.af.addrLen,
Len: 4,
},
// netmask
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: r.af.addrLen,
Len: 4,
Mask: mask,
Xor: xor,
Xor: []byte{0, 0, 0, 0},
},
// net address
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
@@ -2220,12 +2132,13 @@ func getCtNewExprs() []expr.Any {
}
}
func (r *router) getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
// dst offset by default
offset := r.af.dstAddrOffset
func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
// dst offset
offset := uint32(16)
if isSource {
// src offset
offset = r.af.srcAddrOffset
offset = 12
}
return []expr.Any{
@@ -2233,7 +2146,7 @@ func (r *router) getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool)
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offset,
Len: r.af.addrLen,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,

View File

@@ -90,9 +90,8 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
}
// Build CIDR matching expressions
testRouter := &router{af: afIPv4}
sourceExp := testRouter.applyPrefix(testCase.InputPair.Source.Prefix, true)
destExp := testRouter.applyPrefix(testCase.InputPair.Destination.Prefix, false)
sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true)
destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false)
// Combine all expressions in the correct order
// nolint:gocritic
@@ -509,136 +508,6 @@ func TestNftablesCreateIpSet(t *testing.T) {
}
}
func TestNftablesCreateIpSet_IPv6(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTableIPv6()
require.NoError(t, err, "Failed to create v6 work table")
defer deleteWorkTableIPv6()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
defer func() {
require.NoError(t, r.Reset(), "Failed to reset router")
}()
tests := []struct {
name string
sources []netip.Prefix
expected []netip.Prefix
}{
{
name: "Single IPv6",
sources: []netip.Prefix{netip.MustParsePrefix("2001:db8::1/128")},
},
{
name: "Multiple IPv6 Subnets",
sources: []netip.Prefix{
netip.MustParsePrefix("fd00::/64"),
netip.MustParsePrefix("2001:db8::/48"),
netip.MustParsePrefix("fe80::/10"),
},
},
{
name: "Overlapping IPv6",
sources: []netip.Prefix{
netip.MustParsePrefix("fd00::/48"),
netip.MustParsePrefix("fd00::/64"),
netip.MustParsePrefix("fd00::1/128"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("fd00::/48"),
},
},
{
name: "Mixed prefix lengths",
sources: []netip.Prefix{
netip.MustParsePrefix("2001:db8:1::/48"),
netip.MustParsePrefix("2001:db8:2::1/128"),
netip.MustParsePrefix("fd00:abcd::/32"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setName := firewall.NewPrefixSet(tt.sources).HashedName()
set, err := r.createIpSet(setName, setInput{prefixes: tt.sources})
require.NoError(t, err, "Failed to create IPv6 set")
require.NotNil(t, set)
assert.Equal(t, setName, set.Name)
assert.True(t, set.Interval)
assert.Equal(t, nftables.TypeIP6Addr, set.KeyType)
fetchedSet, err := r.conn.GetSetByName(r.workTable, setName)
require.NoError(t, err, "Failed to fetch created set")
elements, err := r.conn.GetSetElements(fetchedSet)
require.NoError(t, err, "Failed to get set elements")
uniquePrefixes := make(map[string]bool)
for _, elem := range elements {
if !elem.IntervalEnd && len(elem.Key) == 16 {
ip := netip.AddrFrom16([16]byte(elem.Key))
uniquePrefixes[ip.String()] = true
}
}
expectedCount := len(tt.expected)
if expectedCount == 0 {
expectedCount = len(tt.sources)
}
assert.Equal(t, expectedCount, len(uniquePrefixes), "unique prefix count mismatch")
r.conn.DelSet(set)
require.NoError(t, r.conn.Flush())
})
}
}
func createWorkTableIPv6() (*nftables.Table, error) {
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return nil, err
}
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6)
if err != nil {
return nil, err
}
for _, t := range tables {
if t.Name == tableNameNetbird {
sConn.DelTable(t)
}
}
table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv6})
err = sConn.Flush()
return table, err
}
func deleteWorkTableIPv6() {
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return
}
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6)
if err != nil {
return
}
for _, t := range tables {
if t.Name == tableNameNetbird {
sConn.DelTable(t)
_ = sConn.Flush()
}
}
}
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
t.Helper()
@@ -758,7 +627,7 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
var metaFound, cmpFound bool
expectedProto, _ := afIPv4.protoNum(proto)
expectedProto, _ := protoToInt(proto)
for _, e := range exprs {
switch ex := e.(type) {
case *expr.Meta:
@@ -985,55 +854,3 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
}
assert.Equal(t, 1, found, "NAT rule should exist in kernel")
}
func TestCalculateLastIP(t *testing.T) {
tests := []struct {
prefix string
want string
}{
{"10.0.0.0/24", "10.0.0.255"},
{"10.0.0.0/32", "10.0.0.0"},
{"0.0.0.0/0", "255.255.255.255"},
{"192.168.1.0/28", "192.168.1.15"},
{"fd00::/64", "fd00::ffff:ffff:ffff:ffff"},
{"fd00::/128", "fd00::"},
{"2001:db8::/48", "2001:db8:0:ffff:ffff:ffff:ffff:ffff"},
{"::/0", "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"},
}
for _, tt := range tests {
t.Run(tt.prefix, func(t *testing.T) {
prefix := netip.MustParsePrefix(tt.prefix)
got := calculateLastIP(prefix)
assert.Equal(t, tt.want, got.String())
})
}
}
func TestConvertPrefixesToSet_IPv6(t *testing.T) {
r := &router{af: afIPv6}
prefixes := []netip.Prefix{
netip.MustParsePrefix("fd00::/64"),
netip.MustParsePrefix("2001:db8::1/128"),
}
elements := r.convertPrefixesToSet(prefixes)
// Each prefix produces 2 elements (start + end)
require.Len(t, elements, 4)
// fd00::/64 start
assert.Equal(t, netip.MustParseAddr("fd00::").As16(), [16]byte(elements[0].Key))
assert.False(t, elements[0].IntervalEnd)
// fd00::/64 end (fd00:0:0:1::, one past the last)
assert.Equal(t, netip.MustParseAddr("fd00:0:0:1::").As16(), [16]byte(elements[1].Key))
assert.True(t, elements[1].IntervalEnd)
// 2001:db8::1/128 start
assert.Equal(t, netip.MustParseAddr("2001:db8::1").As16(), [16]byte(elements[2].Key))
assert.False(t, elements[2].IntervalEnd)
// 2001:db8::1/128 end (2001:db8::2)
assert.Equal(t, netip.MustParseAddr("2001:db8::2").As16(), [16]byte(elements[3].Key))
assert.True(t, elements[3].IntervalEnd)
}

View File

@@ -5,10 +5,8 @@ import (
"os/exec"
"syscall"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -31,20 +29,15 @@ func (m *Manager) Close(*statemanager.Manager) error {
return nil
}
var merr *multierror.Error
if isFirewallRuleActive(firewallRuleName) {
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove windows firewall rule: %w", err))
}
if !isFirewallRuleActive(firewallRuleName) {
return nil
}
if isFirewallRuleActive(firewallRuleName + "-v6") {
if err := manageFirewallRule(firewallRuleName+"-v6", deleteRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove windows v6 firewall rule: %w", err))
}
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
return fmt.Errorf("couldn't remove windows firewall: %w", err)
}
return nberrors.FormatErrorOrNil(merr)
return nil
}
// AllowNetbird allows netbird interface traffic
@@ -53,33 +46,17 @@ func (m *Manager) AllowNetbird() error {
return nil
}
if !isFirewallRuleActive(firewallRuleName) {
if err := manageFirewallRule(firewallRuleName,
addRule,
"dir=in",
"enable=yes",
"action=allow",
"profile=any",
"localip="+m.wgIface.Address().IP.String(),
); err != nil {
return err
}
if isFirewallRuleActive(firewallRuleName) {
return nil
}
if v6 := m.wgIface.Address().IPv6; v6.IsValid() && !isFirewallRuleActive(firewallRuleName+"-v6") {
if err := manageFirewallRule(firewallRuleName+"-v6",
addRule,
"dir=in",
"enable=yes",
"action=allow",
"profile=any",
"localip="+v6.String(),
); err != nil {
return err
}
}
return nil
return manageFirewallRule(firewallRuleName,
addRule,
"dir=in",
"enable=yes",
"action=allow",
"profile=any",
"localip="+m.wgIface.Address().IP.String(),
)
}
func manageFirewallRule(ruleName string, action action, extraArgs ...string) error {

View File

@@ -1,9 +1,8 @@
package conntrack
import (
"net"
"fmt"
"net/netip"
"strconv"
"sync/atomic"
"time"
@@ -65,7 +64,5 @@ type ConnKey struct {
}
func (c ConnKey) String() string {
return net.JoinHostPort(c.SrcIP.Unmap().String(), strconv.Itoa(int(c.SrcPort))) +
" → " +
net.JoinHostPort(c.DstIP.Unmap().String(), strconv.Itoa(int(c.DstPort)))
return fmt.Sprintf("%s:%d → %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
}

View File

@@ -13,54 +13,6 @@ import (
var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
func TestConnKey_String(t *testing.T) {
tests := []struct {
name string
key ConnKey
expect string
}{
{
name: "IPv4",
key: ConnKey{
SrcIP: netip.MustParseAddr("192.168.1.1"),
DstIP: netip.MustParseAddr("10.0.0.1"),
SrcPort: 12345,
DstPort: 80,
},
expect: "192.168.1.1:12345 → 10.0.0.1:80",
},
{
name: "IPv6",
key: ConnKey{
SrcIP: netip.MustParseAddr("2001:db8::1"),
DstIP: netip.MustParseAddr("2001:db8::2"),
SrcPort: 54321,
DstPort: 443,
},
expect: "[2001:db8::1]:54321 → [2001:db8::2]:443",
},
{
name: "IPv4-mapped IPv6 unmaps",
key: ConnKey{
SrcIP: netip.MustParseAddr("::ffff:10.0.0.1"),
DstIP: netip.MustParseAddr("::ffff:10.0.0.2"),
SrcPort: 1000,
DstPort: 2000,
},
expect: "10.0.0.1:1000 → 10.0.0.2:2000",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := tc.key.String()
if got != tc.expect {
t.Errorf("got %q, want %q", got, tc.expect)
}
})
}
}
// Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) {
b.Run("TCPHighLoad", func(b *testing.B) {

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"net"
"net/netip"
"strconv"
"sync"
"time"
@@ -22,14 +21,9 @@ const (
// ICMPCleanupInterval is how often we check for stale ICMP connections
ICMPCleanupInterval = 15 * time.Second
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info.
// IPv4: 20-byte header + 8-byte transport = 28 bytes.
// IPv6: 40-byte header + 8-byte transport = 48 bytes.
MaxICMPPayloadLength = 48
// minICMPPayloadIPv4 is the minimum embedded packet length for IPv4 ICMP errors.
minICMPPayloadIPv4 = 28
// minICMPPayloadIPv6 is the minimum embedded packet length for IPv6 ICMP errors.
minICMPPayloadIPv6 = 48
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info,
// which includes the IP header (20 bytes) and transport header (8 bytes)
MaxICMPPayloadLength = 28
)
// ICMPConnKey uniquely identifies an ICMP connection
@@ -71,7 +65,7 @@ type ICMPInfo struct {
// String implements fmt.Stringer for lazy evaluation in log messages
func (info ICMPInfo) String() string {
if info.isErrorMessage() && info.PayloadLen >= minICMPPayloadIPv4 {
if info.isErrorMessage() && info.PayloadLen >= MaxICMPPayloadLength {
if origInfo := info.parseOriginalPacket(); origInfo != "" {
return fmt.Sprintf("%s (original: %s)", info.TypeCode, origInfo)
}
@@ -80,72 +74,42 @@ func (info ICMPInfo) String() string {
return info.TypeCode.String()
}
// isErrorMessage returns true if this ICMP type carries original packet info.
// Covers both ICMPv4 and ICMPv6 error types. Without a family field we match
// both sets; type 3 overlaps (v4 DestUnreachable / v6 TimeExceeded) so it's
// kept as a literal.
// isErrorMessage returns true if this ICMP type carries original packet info
func (info ICMPInfo) isErrorMessage() bool {
typ := info.TypeCode.Type()
// ICMPv4 error types
if typ == layers.ICMPv4TypeDestinationUnreachable ||
typ == layers.ICMPv4TypeRedirect ||
typ == layers.ICMPv4TypeTimeExceeded ||
typ == layers.ICMPv4TypeParameterProblem {
return true
}
// ICMPv6 error types (type 3 already matched above as v4 DestUnreachable)
if typ == layers.ICMPv6TypeDestinationUnreachable ||
typ == layers.ICMPv6TypePacketTooBig ||
typ == layers.ICMPv6TypeParameterProblem {
return true
}
return false
return typ == 3 || // Destination Unreachable
typ == 5 || // Redirect
typ == 11 || // Time Exceeded
typ == 12 // Parameter Problem
}
// parseOriginalPacket extracts info about the original packet from ICMP payload
func (info ICMPInfo) parseOriginalPacket() string {
if info.PayloadLen == 0 {
if info.PayloadLen < MaxICMPPayloadLength {
return ""
}
version := (info.PayloadData[0] >> 4) & 0xF
var protocol uint8
var srcIP, dstIP net.IP
var transportData []byte
switch version {
case 4:
if info.PayloadLen < minICMPPayloadIPv4 {
return ""
}
protocol = info.PayloadData[9]
srcIP = net.IP(info.PayloadData[12:16])
dstIP = net.IP(info.PayloadData[16:20])
transportData = info.PayloadData[20:]
case 6:
if info.PayloadLen < minICMPPayloadIPv6 {
return ""
}
// Next Header field in IPv6 header
protocol = info.PayloadData[6]
srcIP = net.IP(info.PayloadData[8:24])
dstIP = net.IP(info.PayloadData[24:40])
transportData = info.PayloadData[40:]
default:
// TODO: handle IPv6
if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 {
return ""
}
protocol := info.PayloadData[9]
srcIP := net.IP(info.PayloadData[12:16])
dstIP := net.IP(info.PayloadData[16:20])
transportData := info.PayloadData[20:]
switch nftypes.Protocol(protocol) {
case nftypes.TCP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
return "TCP " + net.JoinHostPort(srcIP.String(), strconv.Itoa(int(srcPort))) + " → " + net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort)))
return fmt.Sprintf("TCP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
case nftypes.UDP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
return "UDP " + net.JoinHostPort(srcIP.String(), strconv.Itoa(int(srcPort))) + " → " + net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort)))
return fmt.Sprintf("UDP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
case nftypes.ICMP:
icmpType := transportData[0]
@@ -283,10 +247,9 @@ func (t *ICMPTracker) track(
t.sendEvent(nftypes.TypeStart, conn, ruleId)
}
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request.
// Accepts both ICMPv4 (type 0) and ICMPv6 (type 129) echo replies.
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
if icmpType != uint8(layers.ICMPv4TypeEchoReply) && icmpType != uint8(layers.ICMPv6TypeEchoReply) {
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
return false
}
@@ -338,13 +301,6 @@ func (t *ICMPTracker) cleanup() {
}
}
func icmpProtocolForAddr(ip netip.Addr) nftypes.Protocol {
if ip.Is6() {
return nftypes.ICMPv6
}
return nftypes.ICMP
}
// Close stops the cleanup routine and releases resources
func (t *ICMPTracker) Close() {
t.tickerCancel()
@@ -360,7 +316,7 @@ func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []
Type: typ,
RuleID: ruleID,
Direction: conn.Direction,
Protocol: icmpProtocolForAddr(conn.SourceIP),
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
ICMPType: conn.ICMPType,
@@ -378,7 +334,7 @@ func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Ad
Type: nftypes.TypeStart,
RuleID: ruleID,
Direction: direction,
Protocol: icmpProtocolForAddr(srcIP),
Protocol: nftypes.ICMP,
SourceIP: srcIP,
DestIP: dstIP,
ICMPType: typ,

View File

@@ -5,42 +5,6 @@ import (
"testing"
)
func TestICMPConnKey_String(t *testing.T) {
tests := []struct {
name string
key ICMPConnKey
expect string
}{
{
name: "IPv4",
key: ICMPConnKey{
SrcIP: netip.MustParseAddr("192.168.1.1"),
DstIP: netip.MustParseAddr("10.0.0.1"),
ID: 1234,
},
expect: "192.168.1.1 → 10.0.0.1 (id 1234)",
},
{
name: "IPv6",
key: ICMPConnKey{
SrcIP: netip.MustParseAddr("2001:db8::1"),
DstIP: netip.MustParseAddr("2001:db8::2"),
ID: 5678,
},
expect: "2001:db8::1 → 2001:db8::2 (id 5678)",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := tc.key.String()
if got != tc.expect {
t.Errorf("got %q, want %q", got, tc.expect)
}
})
}
}
func BenchmarkICMPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)

View File

@@ -18,10 +18,9 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/uuid"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
@@ -36,10 +35,8 @@ import (
const (
layerTypeAll = 255
// ipv4TCPHeaderMinSize represents minimum IPv4 (20) + TCP (20) header size for MSS calculation
ipv4TCPHeaderMinSize = 40
// ipv6TCPHeaderMinSize represents minimum IPv6 (40) + TCP (20) header size for MSS calculation
ipv6TCPHeaderMinSize = 60
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
)
// serviceKey represents a protocol/port combination for netstack service registry
@@ -126,7 +123,7 @@ type Manager struct {
logger *nblog.Logger
flowLogger nftypes.FlowLogger
blockRules []firewall.Rule
blockRule firewall.Rule
// Internal 1:1 DNAT
dnatEnabled atomic.Bool
@@ -141,10 +138,9 @@ type Manager struct {
netstackServices map[serviceKey]struct{}
netstackServiceMutex sync.RWMutex
mtu uint16
mssClampValueIPv4 uint16
mssClampValueIPv6 uint16
mssClampEnabled bool
mtu uint16
mssClampValue uint16
mssClampEnabled bool
// Only one hook per protocol is supported. Outbound direction only.
udpHookOut atomic.Pointer[common.PacketHook]
@@ -161,28 +157,11 @@ type decoder struct {
icmp4 layers.ICMPv4
icmp6 layers.ICMPv6
decoded []gopacket.LayerType
parser4 *gopacket.DecodingLayerParser
parser6 *gopacket.DecodingLayerParser
parser *gopacket.DecodingLayerParser
dnatOrigPort uint16
}
// decodePacket decodes packet data using the appropriate parser based on IP version.
func (d *decoder) decodePacket(data []byte) error {
if len(data) == 0 {
return errors.New("empty packet")
}
version := data[0] >> 4
switch version {
case 4:
return d.parser4.DecodeLayers(data, &d.decoded)
case 6:
return d.parser6.DecodeLayers(data, &d.decoded)
default:
return fmt.Errorf("unknown IP version %d", version)
}
}
// Create userspace firewall manager constructor
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
return create(iface, nil, disableServerRoutes, flowLogger, mtu)
@@ -240,17 +219,11 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser4 = gopacket.NewDecodingLayerParser(
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser4.IgnoreUnsupported = true
d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
d.parser.IgnoreUnsupported = true
return d
},
},
@@ -276,12 +249,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
if !disableMSSClamping {
m.mssClampEnabled = true
if mtu > ipv4TCPHeaderMinSize {
m.mssClampValueIPv4 = mtu - ipv4TCPHeaderMinSize
}
if mtu > ipv6TCPHeaderMinSize {
m.mssClampValueIPv6 = mtu - ipv6TCPHeaderMinSize
}
m.mssClampValue = mtu - ipTCPHeaderMinSize
}
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
return nil, fmt.Errorf("update local IPs: %w", err)
@@ -304,25 +272,13 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
return m, nil
}
// blockInvalidRouted installs drop rules for traffic to the wg overlay that
// arrives via the routing path. v4 and v6 are independent: a v6 install
// failure leaves v4 protection in place (and vice versa) so the returned
// slice always contains whatever was successfully installed, even on error.
// Callers must persist the slice so DisableRouting can clean partial state.
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) ([]firewall.Rule, error) {
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) {
wgPrefix := iface.Address().Network
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
sources := []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
v6Net := iface.Address().IPv6Net
if v6Net.IsValid() {
sources = append(sources, netip.PrefixFrom(netip.IPv6Unspecified(), 0))
}
var rules []firewall.Rule
v4Rule, err := m.addRouteFiltering(
rule, err := m.addRouteFiltering(
nil,
sources,
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
firewall.Network{Prefix: wgPrefix},
firewall.ProtocolALL,
nil,
@@ -330,30 +286,12 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) ([]firewall.Rule,
firewall.ActionDrop,
)
if err != nil {
return rules, fmt.Errorf("block wg v4 net: %w", err)
}
rules = append(rules, v4Rule)
if v6Net.IsValid() {
log.Debugf("blocking invalid routed traffic for %s", v6Net)
v6Rule, err := m.addRouteFiltering(
nil,
sources,
firewall.Network{Prefix: v6Net},
firewall.ProtocolALL,
nil,
nil,
firewall.ActionDrop,
)
if err != nil {
return rules, fmt.Errorf("block wg v6 net: %w", err)
}
rules = append(rules, v6Rule)
return nil, fmt.Errorf("block wg nte : %w", err)
}
// TODO: Block networks that we're a client of
return rules, nil
return rule, nil
}
func (m *Manager) determineRouting() error {
@@ -583,7 +521,7 @@ func (m *Manager) addRouteFiltering(
mgmtId: id,
sources: sources,
dstSet: destination.Set,
protoLayer: protoToLayer(proto, ipLayerFromPrefix(destination.Prefix)),
protoLayer: protoToLayer(proto, layers.LayerTypeIPv4),
srcPort: sPort,
dstPort: dPort,
action: action,
@@ -674,10 +612,10 @@ func (m *Manager) Flush() error { return nil }
// resetState clears all firewall rules and closes connection trackers.
// Must be called with m.mutex held.
func (m *Manager) resetState() {
clear(m.outgoingRules)
clear(m.incomingDenyRules)
clear(m.incomingRules)
clear(m.routeRulesMap)
maps.Clear(m.outgoingRules)
maps.Clear(m.incomingDenyRules)
maps.Clear(m.incomingRules)
maps.Clear(m.routeRulesMap)
m.routeRules = m.routeRules[:0]
m.udpHookOut.Store(nil)
m.tcpHookOut.Store(nil)
@@ -738,7 +676,11 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
}
destinations := matches[0].destinations
destinations = append(destinations, prefixes...)
for _, prefix := range prefixes {
if prefix.Addr().Is4() {
destinations = append(destinations, prefix)
}
}
slices.SortFunc(destinations, func(a, b netip.Prefix) int {
cmp := a.Addr().Compare(b.Addr())
@@ -777,7 +719,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
if err := d.decodePacket(packetData); err != nil {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
return false
}
@@ -861,32 +803,12 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
return false
}
var mssClampValue uint16
var ipHeaderSize int
switch d.decoded[0] {
case layers.LayerTypeIPv4:
mssClampValue = m.mssClampValueIPv4
ipHeaderSize = int(d.ip4.IHL) * 4
if ipHeaderSize < 20 {
return false
}
case layers.LayerTypeIPv6:
mssClampValue = m.mssClampValueIPv6
ipHeaderSize = 40
default:
return false
}
if mssClampValue == 0 {
return false
}
mssOptionIndex := -1
var currentMSS uint16
for i, opt := range d.tcp.Options {
if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 {
currentMSS = binary.BigEndian.Uint16(opt.OptionData)
if currentMSS > mssClampValue {
if currentMSS > m.mssClampValue {
mssOptionIndex = i
break
}
@@ -897,15 +819,20 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
return false
}
if !m.updateMSSOption(packetData, d, mssOptionIndex, mssClampValue, ipHeaderSize) {
ipHeaderSize := int(d.ip4.IHL) * 4
if ipHeaderSize < 20 {
return false
}
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, mssClampValue)
if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) {
return false
}
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue)
return true
}
func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex int, mssClampValue uint16, ipHeaderSize int) bool {
func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool {
tcpHeaderStart := ipHeaderSize
tcpOptionsStart := tcpHeaderStart + 20
@@ -920,7 +847,7 @@ func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex
}
mssValueOffset := optOffset + 2
binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], mssClampValue)
binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue)
m.recalculateTCPChecksum(packetData, d, tcpHeaderStart)
return true
@@ -930,32 +857,18 @@ func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeade
tcpLayer := packetData[tcpHeaderStart:]
tcpLength := len(packetData) - tcpHeaderStart
// Zero out existing checksum
tcpLayer[16] = 0
tcpLayer[17] = 0
// Build pseudo-header checksum based on IP version
var pseudoSum uint32
switch d.decoded[0] {
case layers.LayerTypeIPv4:
pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1])
pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3])
pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1])
pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3])
pseudoSum += uint32(d.ip4.Protocol)
pseudoSum += uint32(tcpLength)
case layers.LayerTypeIPv6:
for i := 0; i < 16; i += 2 {
pseudoSum += uint32(d.ip6.SrcIP[i])<<8 | uint32(d.ip6.SrcIP[i+1])
}
for i := 0; i < 16; i += 2 {
pseudoSum += uint32(d.ip6.DstIP[i])<<8 | uint32(d.ip6.DstIP[i+1])
}
pseudoSum += uint32(tcpLength)
pseudoSum += uint32(layers.IPProtocolTCP)
}
pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1])
pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3])
pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1])
pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3])
pseudoSum += uint32(d.ip4.Protocol)
pseudoSum += uint32(tcpLength)
sum := pseudoSum
var sum = pseudoSum
for i := 0; i < tcpLength-1; i += 2 {
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
}
@@ -993,9 +906,6 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData
}
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
case layers.LayerTypeICMPv6:
id, tc := icmpv6EchoFields(d)
m.icmpTracker.TrackOutbound(srcIP, dstIP, id, tc, d.icmp6.Payload, size)
}
}
@@ -1009,9 +919,6 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort)
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
case layers.LayerTypeICMPv6:
id, tc := icmpv6EchoFields(d)
m.icmpTracker.TrackInbound(srcIP, dstIP, id, tc, ruleID, d.icmp6.Payload, size)
}
d.dnatOrigPort = 0
@@ -1044,19 +951,15 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
// TODO: pass fragments of routed packets to forwarder
if fragment {
if d.decoded[0] == layers.LayerTypeIPv4 {
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
} else {
m.logger.Trace2("packet is an IPv6 fragment: src=%v dst=%v", srcIP, dstIP)
}
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
return false
}
// TODO: optimize port DNAT by caching matched rules in conntrack
if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated {
// Re-decode after port DNAT translation to update port information
if err := d.decodePacket(packetData); err != nil {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error1("failed to re-decode packet after port DNAT: %v", err)
return true
}
@@ -1065,7 +968,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
if translated := m.translateInboundReverse(packetData, d); translated {
// Re-decode after translation to get original addresses
if err := d.decodePacket(packetData); err != nil {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err)
return true
}
@@ -1197,48 +1100,6 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
return true
}
// icmpv6EchoFields extracts the echo identifier from an ICMPv6 packet and maps
// the ICMPv6 type code to an ICMPv4TypeCode so the ICMP conntrack can handle
// both families uniformly. The echo ID is in the first two payload bytes.
func icmpv6EchoFields(d *decoder) (id uint16, tc layers.ICMPv4TypeCode) {
if len(d.icmp6.Payload) >= 2 {
id = uint16(d.icmp6.Payload[0])<<8 | uint16(d.icmp6.Payload[1])
}
// Map ICMPv6 echo types to ICMPv4 equivalents for unified tracking.
switch d.icmp6.TypeCode.Type() {
case layers.ICMPv6TypeEchoRequest:
tc = layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0)
case layers.ICMPv6TypeEchoReply:
tc = layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoReply, 0)
default:
tc = layers.CreateICMPv4TypeCode(d.icmp6.TypeCode.Type(), d.icmp6.TypeCode.Code())
}
return id, tc
}
// protoLayerMatches checks if a packet's protocol layer matches a rule's expected
// protocol layer. ICMPv4 and ICMPv6 are treated as equivalent when matching
// ICMP rules since management sends a single ICMP rule for both families.
func protoLayerMatches(ruleLayer, packetLayer gopacket.LayerType) bool {
if ruleLayer == packetLayer {
return true
}
if ruleLayer == layers.LayerTypeICMPv4 && packetLayer == layers.LayerTypeICMPv6 {
return true
}
if ruleLayer == layers.LayerTypeICMPv6 && packetLayer == layers.LayerTypeICMPv4 {
return true
}
return false
}
func ipLayerFromPrefix(p netip.Prefix) gopacket.LayerType {
if p.Addr().Is6() {
return layers.LayerTypeIPv6
}
return layers.LayerTypeIPv4
}
func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType {
switch proto {
case firewall.ProtocolTCP:
@@ -1262,10 +1123,8 @@ func getProtocolFromPacket(d *decoder) nftypes.Protocol {
return nftypes.TCP
case layers.LayerTypeUDP:
return nftypes.UDP
case layers.LayerTypeICMPv4:
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return nftypes.ICMP
case layers.LayerTypeICMPv6:
return nftypes.ICMPv6
default:
return nftypes.ProtocolUnknown
}
@@ -1286,7 +1145,7 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
// It returns true, false if the packet is valid and not a fragment.
// It returns true, true if the packet is a fragment and valid.
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
if err := d.decodePacket(packetData); err != nil {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Trace1("couldn't decode packet, err: %s", err)
return false, false
}
@@ -1299,21 +1158,10 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
}
// Fragments are also valid
if l == 1 {
switch d.decoded[0] {
case layers.LayerTypeIPv4:
if d.ip4.Flags&layers.IPv4MoreFragments != 0 || d.ip4.FragOffset != 0 {
return true, true
}
case layers.LayerTypeIPv6:
// IPv6 uses Fragment extension header (NextHeader=44). If gopacket
// only decoded the IPv6 layer, the transport is in a fragment.
// TODO: handle non-Fragment extension headers (HopByHop, Routing,
// DestOpts) by walking the chain. gopacket's parser does not
// support them as DecodingLayers; today we drop such packets.
if d.ip6.NextHeader == layers.IPProtocolIPv6Fragment {
return true, true
}
if l == 1 && d.decoded[0] == layers.LayerTypeIPv4 {
ip4 := d.ip4
if ip4.Flags&layers.IPv4MoreFragments != 0 || ip4.FragOffset != 0 {
return true, true
}
}
@@ -1351,35 +1199,21 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr,
size,
)
case layers.LayerTypeICMPv6:
id, _ := icmpv6EchoFields(d)
return m.icmpTracker.IsValidInbound(
srcIP,
dstIP,
id,
d.icmp6.TypeCode.Type(),
size,
)
// TODO: ICMPv6
}
return false
}
// isSpecialICMP returns true if the packet is a special ICMP error packet that should be allowed.
// isSpecialICMP returns true if the packet is a special ICMP packet that should be allowed
func (m *Manager) isSpecialICMP(d *decoder) bool {
switch d.decoded[1] {
case layers.LayerTypeICMPv4:
icmpType := d.icmp4.TypeCode.Type()
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
icmpType == layers.ICMPv4TypeTimeExceeded
case layers.LayerTypeICMPv6:
icmpType := d.icmp6.TypeCode.Type()
return icmpType == layers.ICMPv6TypeDestinationUnreachable ||
icmpType == layers.ICMPv6TypePacketTooBig ||
icmpType == layers.ICMPv6TypeTimeExceeded ||
icmpType == layers.ICMPv6TypeParameterProblem
if d.decoded[1] != layers.LayerTypeICMPv4 {
return false
}
return false
icmpType := d.icmp4.TypeCode.Type()
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
icmpType == layers.ICMPv4TypeTimeExceeded
}
func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte) ([]byte, bool) {
@@ -1436,7 +1270,7 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
return rule.mgmtId, rule.drop, true
}
if !protoLayerMatches(rule.protoLayer, payloadLayer) {
if payloadLayer != rule.protoLayer {
continue
}
@@ -1471,7 +1305,8 @@ func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.Lay
}
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool {
if rule.protoLayer != layerTypeAll && !protoLayerMatches(rule.protoLayer, protoLayer) {
// TODO: handle ipv6 vs ipv4 icmp rules
if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer {
return false
}
@@ -1532,14 +1367,13 @@ func (m *Manager) EnableRouting() error {
return nil
}
rules, err := m.blockInvalidRouted(m.wgIface)
// Persist whatever was installed even on partial failure, so DisableRouting
// can clean it up later.
m.blockRules = rules
rule, err := m.blockInvalidRouted(m.wgIface)
if err != nil {
return fmt.Errorf("block invalid routed: %w", err)
}
m.blockRule = rule
return nil
}
@@ -1555,16 +1389,9 @@ func (m *Manager) DisableRouting() error {
m.routingEnabled.Store(false)
m.nativeRouter.Store(false)
var merr *multierror.Error
for _, rule := range m.blockRules {
if err := m.deleteRouteRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete block rule: %w", err))
}
}
m.blockRules = nil
// don't stop forwarder if in use by netstack
if m.netstack && m.localForwarding {
return nberrors.FormatErrorOrNil(merr)
return nil
}
fwder.Stop()
@@ -1572,7 +1399,14 @@ func (m *Manager) DisableRouting() error {
log.Debug("forwarder stopped")
return nberrors.FormatErrorOrNil(merr)
if m.blockRule != nil {
if err := m.deleteRouteRule(m.blockRule); err != nil {
return fmt.Errorf("delete block rule: %w", err)
}
m.blockRule = nil
}
return nil
}
// RegisterNetstackService registers a service as listening on the netstack for the given protocol and port
@@ -1626,8 +1460,7 @@ func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool {
}
// traffic to our other local interfaces (not NetBird IP) - always forward
addr := m.wgIface.Address()
if dstIP != addr.IP && (!addr.IPv6.IsValid() || dstIP != addr.IPv6) {
if dstIP != m.wgIface.Address().IP {
return true
}

View File

@@ -1023,8 +1023,7 @@ func BenchmarkMSSClamping(b *testing.B) {
}()
manager.mssClampEnabled = true
manager.mssClampValueIPv4 = 1240
manager.mssClampValueIPv6 = 1220
manager.mssClampValue = 1240
srcIP := net.ParseIP("100.64.0.2")
dstIP := net.ParseIP("8.8.8.8")
@@ -1089,8 +1088,7 @@ func BenchmarkMSSClampingOverhead(b *testing.B) {
manager.mssClampEnabled = sc.enabled
if sc.enabled {
manager.mssClampValueIPv4 = 1240
manager.mssClampValueIPv6 = 1220
manager.mssClampValue = 1240
}
srcIP := net.ParseIP("100.64.0.2")
@@ -1143,8 +1141,7 @@ func BenchmarkMSSClampingMemory(b *testing.B) {
}()
manager.mssClampEnabled = true
manager.mssClampValueIPv4 = 1240
manager.mssClampValueIPv6 = 1220
manager.mssClampValue = 1240
srcIP := net.ParseIP("100.64.0.2")
dstIP := net.ParseIP("8.8.8.8")

View File

@@ -539,236 +539,53 @@ func TestPeerACLFiltering(t *testing.T) {
}
}
func TestPeerACLFilteringIPv6(t *testing.T) {
localIP := netip.MustParseAddr("100.10.0.100")
localIPv6 := netip.MustParseAddr("fd00::100")
wgNet := netip.MustParsePrefix("100.10.0.0/16")
wgNetV6 := netip.MustParsePrefix("fd00::/64")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: localIP,
Network: wgNet,
IPv6: localIPv6,
IPv6Net: wgNetV6,
}
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
err = manager.UpdateLocalIPs()
require.NoError(t, err)
testCases := []struct {
name string
srcIP string
dstIP string
proto fw.Protocol
srcPort uint16
dstPort uint16
ruleIP string
ruleProto fw.Protocol
ruleDstPort *fw.Port
ruleAction fw.Action
shouldBeBlocked bool
}{
{
name: "IPv6: allow TCP from peer",
srcIP: "fd00::1",
dstIP: "fd00::100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "fd00::1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "IPv6: allow UDP from peer",
srcIP: "fd00::1",
dstIP: "fd00::100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 53,
ruleIP: "fd00::1",
ruleProto: fw.ProtocolUDP,
ruleDstPort: &fw.Port{Values: []uint16{53}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "IPv6: allow ICMPv6 from peer",
srcIP: "fd00::1",
dstIP: "fd00::100",
proto: fw.ProtocolICMP,
ruleIP: "fd00::1",
ruleProto: fw.ProtocolICMP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "IPv6: block TCP without rule",
srcIP: "fd00::2",
dstIP: "fd00::100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "fd00::1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "IPv6: drop rule",
srcIP: "fd00::1",
dstIP: "fd00::100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 22,
ruleIP: "fd00::1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{22}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "IPv6: allow all protocols",
srcIP: "fd00::1",
dstIP: "fd00::100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 9999,
ruleIP: "fd00::1",
ruleProto: fw.ProtocolALL,
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "IPv6: v4 wildcard ICMP rule matches ICMPv6 via protoLayerMatches",
srcIP: "fd00::1",
dstIP: "fd00::100",
proto: fw.ProtocolICMP,
ruleIP: "0.0.0.0",
ruleProto: fw.ProtocolICMP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
}
t.Run("IPv6 implicit DROP (no rules)", func(t *testing.T) {
packet := createTestPacket(t, "fd00::1", "fd00::100", fw.ProtocolTCP, 12345, 443)
isDropped := manager.FilterInbound(packet, 0)
require.True(t, isDropped, "IPv6 packet should be dropped when no rules exist")
})
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.ruleAction == fw.ActionDrop {
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
require.NoError(t, err)
t.Cleanup(func() {
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
}
})
}
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction, "")
require.NoError(t, err)
require.NotEmpty(t, rules)
t.Cleanup(func() {
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
}
})
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
isDropped := manager.FilterInbound(packet, 0)
require.Equal(t, tc.shouldBeBlocked, isDropped, "packet filter result mismatch")
})
}
}
func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcPort, dstPort uint16) []byte {
t.Helper()
src := net.ParseIP(srcIP)
dst := net.ParseIP(dstIP)
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
// Detect address family
isV6 := src.To4() == nil
ipLayer := &layers.IPv4{
Version: 4,
TTL: 64,
SrcIP: net.ParseIP(srcIP),
DstIP: net.ParseIP(dstIP),
}
var err error
switch proto {
case fw.ProtocolTCP:
ipLayer.Protocol = layers.IPProtocolTCP
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
}
err = tcp.SetNetworkLayerForChecksum(ipLayer)
require.NoError(t, err)
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcp)
if isV6 {
ip6 := &layers.IPv6{
Version: 6,
HopLimit: 64,
SrcIP: src,
DstIP: dst,
case fw.ProtocolUDP:
ipLayer.Protocol = layers.IPProtocolUDP
udp := &layers.UDP{
SrcPort: layers.UDPPort(srcPort),
DstPort: layers.UDPPort(dstPort),
}
err = udp.SetNetworkLayerForChecksum(ipLayer)
require.NoError(t, err)
err = gopacket.SerializeLayers(buf, opts, ipLayer, udp)
switch proto {
case fw.ProtocolTCP:
ip6.NextHeader = layers.IPProtocolTCP
tcp := &layers.TCP{SrcPort: layers.TCPPort(srcPort), DstPort: layers.TCPPort(dstPort)}
_ = tcp.SetNetworkLayerForChecksum(ip6)
err = gopacket.SerializeLayers(buf, opts, ip6, tcp)
case fw.ProtocolUDP:
ip6.NextHeader = layers.IPProtocolUDP
udp := &layers.UDP{SrcPort: layers.UDPPort(srcPort), DstPort: layers.UDPPort(dstPort)}
_ = udp.SetNetworkLayerForChecksum(ip6)
err = gopacket.SerializeLayers(buf, opts, ip6, udp)
case fw.ProtocolICMP:
ip6.NextHeader = layers.IPProtocolICMPv6
icmp := &layers.ICMPv6{
TypeCode: layers.CreateICMPv6TypeCode(layers.ICMPv6TypeEchoRequest, 0),
}
_ = icmp.SetNetworkLayerForChecksum(ip6)
err = gopacket.SerializeLayers(buf, opts, ip6, icmp)
default:
err = gopacket.SerializeLayers(buf, opts, ip6)
}
} else {
ip4 := &layers.IPv4{
Version: 4,
TTL: 64,
SrcIP: src,
DstIP: dst,
case fw.ProtocolICMP:
ipLayer.Protocol = layers.IPProtocolICMPv4
icmp := &layers.ICMPv4{
TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0),
}
err = gopacket.SerializeLayers(buf, opts, ipLayer, icmp)
switch proto {
case fw.ProtocolTCP:
ip4.Protocol = layers.IPProtocolTCP
tcp := &layers.TCP{SrcPort: layers.TCPPort(srcPort), DstPort: layers.TCPPort(dstPort)}
_ = tcp.SetNetworkLayerForChecksum(ip4)
err = gopacket.SerializeLayers(buf, opts, ip4, tcp)
case fw.ProtocolUDP:
ip4.Protocol = layers.IPProtocolUDP
udp := &layers.UDP{SrcPort: layers.UDPPort(srcPort), DstPort: layers.UDPPort(dstPort)}
_ = udp.SetNetworkLayerForChecksum(ip4)
err = gopacket.SerializeLayers(buf, opts, ip4, udp)
case fw.ProtocolICMP:
ip4.Protocol = layers.IPProtocolICMPv4
icmp := &layers.ICMPv4{TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0)}
err = gopacket.SerializeLayers(buf, opts, ip4, icmp)
default:
err = gopacket.SerializeLayers(buf, opts, ip4)
}
default:
err = gopacket.SerializeLayers(buf, opts, ipLayer)
}
require.NoError(t, err)
@@ -1681,103 +1498,3 @@ func TestRouteACLSet(t *testing.T) {
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
}
// TestRouteACLFilteringIPv6 tests IPv6 route ACL matching directly via routeACLsPass.
// Note: full FilterInbound for routed IPv6 traffic drops at the forwarder stage (IPv4-only)
// but the ACL decision itself is correct.
func TestRouteACLFilteringIPv6(t *testing.T) {
manager := setupRoutedManager(t, "10.10.0.100/16")
v6Dst := netip.MustParsePrefix("fd00:dead:beef::/48")
_, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
fw.Network{Prefix: v6Dst},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{80}},
fw.ActionAccept,
)
require.NoError(t, err)
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
fw.Network{Prefix: netip.MustParsePrefix("fd00:dead:beef:1::/64")},
fw.ProtocolALL,
nil,
nil,
fw.ActionDrop,
)
require.NoError(t, err)
tests := []struct {
name string
srcIP netip.Addr
dstIP netip.Addr
proto gopacket.LayerType
srcPort uint16
dstPort uint16
allowed bool
}{
{
name: "IPv6 TCP to allowed dest",
srcIP: netip.MustParseAddr("fd00::1"),
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
proto: layers.LayerTypeTCP,
srcPort: 12345,
dstPort: 80,
allowed: true,
},
{
name: "IPv6 TCP wrong port",
srcIP: netip.MustParseAddr("fd00::1"),
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
proto: layers.LayerTypeTCP,
srcPort: 12345,
dstPort: 443,
allowed: false,
},
{
name: "IPv6 UDP not matched by TCP rule",
srcIP: netip.MustParseAddr("fd00::1"),
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
proto: layers.LayerTypeUDP,
srcPort: 12345,
dstPort: 80,
allowed: false,
},
{
name: "IPv6 ICMPv6 matches ICMP rule via protoLayerMatches",
srcIP: netip.MustParseAddr("fd00::1"),
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
proto: layers.LayerTypeICMPv6,
allowed: false,
},
{
name: "IPv6 to denied subnet",
srcIP: netip.MustParseAddr("fd00::1"),
dstIP: netip.MustParseAddr("fd00:dead:beef:1::1"),
proto: layers.LayerTypeTCP,
srcPort: 12345,
dstPort: 80,
allowed: false,
},
{
name: "IPv6 source outside allowed range",
srcIP: netip.MustParseAddr("fe80::1"),
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
proto: layers.LayerTypeTCP,
srcPort: 12345,
dstPort: 80,
allowed: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, pass := manager.routeACLsPass(tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
require.Equal(t, tc.allowed, pass, "route ACL result mismatch")
})
}
}

View File

@@ -189,21 +189,21 @@ func TestBlockInvalidRoutedIdempotent(t *testing.T) {
})
// Call blockInvalidRouted directly multiple times
rules1, err := manager.blockInvalidRouted(ifaceMock)
rule1, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.NotEmpty(t, rules1)
require.NotNil(t, rule1)
rules2, err := manager.blockInvalidRouted(ifaceMock)
rule2, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.NotEmpty(t, rules2)
require.NotNil(t, rule2)
rules3, err := manager.blockInvalidRouted(ifaceMock)
rule3, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.NotEmpty(t, rules3)
require.NotNil(t, rule3)
// All calls should return the same v4 block rule (idempotent install).
assert.Equal(t, rules1[0].ID(), rules2[0].ID(), "Second call should return same v4 rule")
assert.Equal(t, rules2[0].ID(), rules3[0].ID(), "Third call should return same v4 rule")
// All should return the same rule
assert.Equal(t, rule1.ID(), rule2.ID(), "Second call should return same rule")
assert.Equal(t, rule2.ID(), rule3.ID(), "Third call should return same rule")
// Should have exactly 1 route rule
manager.mutex.RLock()

View File

@@ -535,16 +535,11 @@ func TestProcessOutgoingHooks(t *testing.T) {
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser4 = gopacket.NewDecodingLayerParser(
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser4.IgnoreUnsupported = true
d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
d.parser.IgnoreUnsupported = true
return d
},
}
@@ -643,16 +638,11 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser4 = gopacket.NewDecodingLayerParser(
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser4.IgnoreUnsupported = true
d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
d.parser.IgnoreUnsupported = true
return d
},
}
@@ -1058,8 +1048,8 @@ func TestMSSClamping(t *testing.T) {
}()
require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default")
require.Equal(t, uint16(1280-ipv4TCPHeaderMinSize), manager.mssClampValueIPv4, "IPv4 MSS clamp value should be MTU - 40")
require.Equal(t, uint16(1280-ipv6TCPHeaderMinSize), manager.mssClampValueIPv6, "IPv6 MSS clamp value should be MTU - 60")
expectedMSSValue := uint16(1280 - ipTCPHeaderMinSize)
require.Equal(t, expectedMSSValue, manager.mssClampValue, "MSS clamp value should be MTU - 40")
err = manager.UpdateLocalIPs()
require.NoError(t, err)
@@ -1077,7 +1067,7 @@ func TestMSSClamping(t *testing.T) {
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType))
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
require.Equal(t, manager.mssClampValueIPv4, actualMSS, "MSS should be clamped to MTU - 40")
require.Equal(t, expectedMSSValue, actualMSS, "MSS should be clamped to MTU - 40")
})
t.Run("SYN packet with low MSS unchanged", func(t *testing.T) {
@@ -1101,7 +1091,7 @@ func TestMSSClamping(t *testing.T) {
d := parsePacket(t, packet)
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
require.Equal(t, manager.mssClampValueIPv4, actualMSS, "MSS in SYN-ACK should be clamped")
require.Equal(t, expectedMSSValue, actualMSS, "MSS in SYN-ACK should be clamped")
})
t.Run("Non-SYN packet unchanged", func(t *testing.T) {
@@ -1273,18 +1263,13 @@ func TestShouldForward(t *testing.T) {
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser4 = gopacket.NewDecodingLayerParser(
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser4.IgnoreUnsupported = true
d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
d.parser.IgnoreUnsupported = true
err = d.decodePacket(buf.Bytes())
err = d.parser.DecodeLayers(buf.Bytes(), &d.decoded)
require.NoError(t, err)
return d
@@ -1344,44 +1329,6 @@ func TestShouldForward(t *testing.T) {
},
}
// Add IPv6 to the interface and test dual-stack cases
wgIPv6 := netip.MustParseAddr("fd00::1")
otherIPv6 := netip.MustParseAddr("fd00::2")
ifaceMock.AddressFunc = func() wgaddr.Address {
return wgaddr.Address{
IP: wgIP,
Network: netip.PrefixFrom(wgIP, 24),
IPv6: wgIPv6,
IPv6Net: netip.PrefixFrom(wgIPv6, 64),
}
}
// Re-create manager to pick up the new address with IPv6
require.NoError(t, manager.Close(nil))
manager, err = Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
v6Cases := []struct {
name string
dstIP netip.Addr
expected bool
description string
}{
{"v6 traffic to other address", otherIPv6, true, "should forward v6 traffic not destined to our v6 address"},
{"v6 traffic to our v6 IP", wgIPv6, false, "should not forward traffic destined to our v6 address"},
{"v4 traffic to other with v6 configured", otherIP, true, "should forward v4 traffic when v6 configured"},
{"v4 traffic to our v4 IP with v6 configured", wgIP, false, "should not forward traffic to our v4 address"},
}
for _, tt := range v6Cases {
t.Run(tt.name, func(t *testing.T) {
manager.localForwarding = true
manager.netstack = false
decoder := createTCPDecoder(8080)
result := manager.shouldForward(decoder, tt.dstIP)
require.Equal(t, tt.expected, result, tt.description)
})
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Configure manager

View File

@@ -1,8 +1,7 @@
package forwarder
import (
"net"
"strconv"
"fmt"
"sync/atomic"
wgdevice "golang.zx2c4.com/wireguard/device"
@@ -55,23 +54,16 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
var written int
for _, pkt := range pkts.AsSlice() {
netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice())
data := stack.PayloadSince(pkt.NetworkHeader())
if data == nil {
continue
}
raw := pkt.NetworkHeader().View().AsSlice()
if len(raw) == 0 {
continue
}
var address tcpip.Address
if raw[0]>>4 == 6 {
address = header.IPv6(raw).DestinationAddress()
} else {
address = header.IPv4(raw).DestinationAddress()
}
pktBytes := data.AsSlice()
address := netHeader.DestinationAddress()
if err := e.device.CreateOutboundPacket(pktBytes, address.AsSlice()); err != nil {
e.logger.Error1("CreateOutboundPacket: %v", err)
continue
@@ -122,7 +114,5 @@ type epID stack.TransportEndpointID
func (i epID) String() string {
// src and remote is swapped
return net.JoinHostPort(i.RemoteAddress.String(), strconv.Itoa(int(i.RemotePort))) +
" → " +
net.JoinHostPort(i.LocalAddress.String(), strconv.Itoa(int(i.LocalPort)))
return fmt.Sprintf("%s:%d → %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
}

View File

@@ -14,7 +14,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
@@ -37,31 +36,25 @@ type Forwarder struct {
logger *nblog.Logger
flowLogger nftypes.FlowLogger
// ruleIdMap is used to store the rule ID for a given connection
ruleIdMap sync.Map
stack *stack.Stack
endpoint *endpoint
udpForwarder *udpForwarder
ctx context.Context
cancel context.CancelFunc
ip tcpip.Address
ipv6 tcpip.Address
netstack bool
hasRawICMPAccess bool
hasRawICMPv6Access bool
pingSemaphore chan struct{}
ruleIdMap sync.Map
stack *stack.Stack
endpoint *endpoint
udpForwarder *udpForwarder
ctx context.Context
cancel context.CancelFunc
ip tcpip.Address
netstack bool
hasRawICMPAccess bool
pingSemaphore chan struct{}
}
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
ipv6.NewProtocol,
},
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{
tcp.NewProtocol,
udp.NewProtocol,
icmp.NewProtocol4,
icmp.NewProtocol6,
},
HandleLocal: false,
})
@@ -80,7 +73,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFrom4(iface.Address().IP.As4()),
Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
PrefixLen: iface.Address().Network.Bits(),
},
}
@@ -89,19 +82,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
return nil, fmt.Errorf("failed to add protocol address: %s", err)
}
if v6 := iface.Address().IPv6; v6.IsValid() {
v6Addr := tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFrom16(v6.As16()),
PrefixLen: iface.Address().IPv6Net.Bits(),
},
}
if err := s.AddProtocolAddress(nicID, v6Addr, stack.AddressProperties{}); err != nil {
return nil, fmt.Errorf("add IPv6 protocol address: %s", err)
}
}
defaultSubnet, err := tcpip.NewSubnet(
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
@@ -110,14 +90,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
return nil, fmt.Errorf("creating default subnet: %w", err)
}
defaultSubnetV6, err := tcpip.NewSubnet(
tcpip.AddrFrom16([16]byte{}),
tcpip.MaskFromBytes(make([]byte, 16)),
)
if err != nil {
return nil, fmt.Errorf("creating default v6 subnet: %w", err)
}
if err := s.SetPromiscuousMode(nicID, true); err != nil {
return nil, fmt.Errorf("set promiscuous mode: %s", err)
}
@@ -126,8 +98,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
}
s.SetRouteTable([]tcpip.Route{
{Destination: defaultSubnet, NIC: nicID},
{Destination: defaultSubnetV6, NIC: nicID},
{
Destination: defaultSubnet,
NIC: nicID,
},
})
ctx, cancel := context.WithCancel(context.Background())
@@ -140,8 +114,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
ctx: ctx,
cancel: cancel,
netstack: netstack,
ip: tcpip.AddrFrom4(iface.Address().IP.As4()),
ipv6: addrFromNetipAddr(iface.Address().IPv6),
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
pingSemaphore: make(chan struct{}, 3),
}
@@ -158,10 +131,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
udpForwarder := udp.NewForwarder(s, f.handleUDP)
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
// ICMP is handled directly in InjectIncomingPacket, bypassing gVisor's
// network layer. This avoids duplicate echo replies (v4) and the v6
// auto-reply bug where gVisor responds at the network layer before
// our transport handler fires.
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
f.checkICMPCapability()
@@ -180,30 +150,8 @@ func (f *Forwarder) SetCapture(pc PacketCapture) {
}
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
if len(payload) == 0 {
return fmt.Errorf("empty packet")
}
var protoNum tcpip.NetworkProtocolNumber
switch payload[0] >> 4 {
case 4:
if len(payload) < header.IPv4MinimumSize {
return fmt.Errorf("IPv4 packet too small: %d bytes", len(payload))
}
if f.handleICMPDirect(payload) {
return nil
}
protoNum = ipv4.ProtocolNumber
case 6:
if len(payload) < header.IPv6MinimumSize {
return fmt.Errorf("IPv6 packet too small: %d bytes", len(payload))
}
if f.handleICMPDirect(payload) {
return nil
}
protoNum = ipv6.ProtocolNumber
default:
return fmt.Errorf("unknown IP version: %d", payload[0]>>4)
if len(payload) < header.IPv4MinimumSize {
return fmt.Errorf("packet too small: %d bytes", len(payload))
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -212,160 +160,11 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
defer pkt.DecRef()
if f.endpoint.dispatcher != nil {
f.endpoint.dispatcher.DeliverNetworkPacket(protoNum, pkt)
f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
}
return nil
}
// handleICMPDirect intercepts ICMP packets from raw IP payloads before they
// enter gVisor. It synthesizes the TransportEndpointID and PacketBuffer that
// the existing handlers expect, then dispatches to handleICMP/handleICMPv6.
// This bypasses gVisor's network layer which causes duplicate v4 echo replies
// and auto-replies to all v6 echo requests in promiscuous mode.
//
// Unlike gVisor's network layer, this does not validate ICMP checksums or
// reassemble IP fragments. Fragmented ICMP packets fall through to gVisor.
func parseICMPv4(payload []byte) (ipHdrLen, icmpLen int, src, dst tcpip.Address, ok bool) {
if len(payload) < header.IPv4MinimumSize {
return 0, 0, src, dst, false
}
ip := header.IPv4(payload)
if ip.Protocol() != uint8(header.ICMPv4ProtocolNumber) {
return 0, 0, src, dst, false
}
if ip.FragmentOffset() != 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 {
return 0, 0, src, dst, false
}
ipHdrLen = int(ip.HeaderLength())
totalLen := int(ip.TotalLength())
if ipHdrLen < header.IPv4MinimumSize || ipHdrLen > totalLen || totalLen > len(payload) {
return 0, 0, src, dst, false
}
icmpLen = totalLen - ipHdrLen
if icmpLen < header.ICMPv4MinimumSize {
return 0, 0, src, dst, false
}
return ipHdrLen, icmpLen, ip.SourceAddress(), ip.DestinationAddress(), true
}
func parseICMPv6(payload []byte) (ipHdrLen, icmpLen int, src, dst tcpip.Address, ok bool) {
if len(payload) < header.IPv6MinimumSize {
return 0, 0, src, dst, false
}
ip := header.IPv6(payload)
declaredLen := int(ip.PayloadLength())
hdrEnd := header.IPv6MinimumSize + declaredLen
if hdrEnd > len(payload) {
return 0, 0, src, dst, false
}
icmpStart, ok := skipIPv6ExtensionsToICMPv6(payload, ip.NextHeader(), hdrEnd)
if !ok {
return 0, 0, src, dst, false
}
icmpLen = hdrEnd - icmpStart
if icmpLen < header.ICMPv6MinimumSize {
return 0, 0, src, dst, false
}
return icmpStart, icmpLen, ip.SourceAddress(), ip.DestinationAddress(), true
}
// skipIPv6ExtensionsToICMPv6 walks the IPv6 extension-header chain starting
// after the fixed header. It advances past Hop-by-Hop, Routing, and
// Destination Options headers (which share the NextHeader+ExtLen+6+ExtLen*8
// layout) and returns the offset of the ICMPv6 payload. Fragment, ESP, AH,
// and unknown identifiers are reported as not handleable so the caller can
// defer to gVisor.
func skipIPv6ExtensionsToICMPv6(payload []byte, next uint8, hdrEnd int) (int, bool) {
off := header.IPv6MinimumSize
for {
if next == uint8(header.ICMPv6ProtocolNumber) {
return off, true
}
if !isWalkableIPv6ExtHdr(next) {
return 0, false
}
newOff, newNext, ok := advanceIPv6ExtHdr(payload, off, hdrEnd)
if !ok {
return 0, false
}
off = newOff
next = newNext
}
}
func isWalkableIPv6ExtHdr(id uint8) bool {
switch id {
case uint8(header.IPv6HopByHopOptionsExtHdrIdentifier),
uint8(header.IPv6RoutingExtHdrIdentifier),
uint8(header.IPv6DestinationOptionsExtHdrIdentifier):
return true
}
return false
}
func advanceIPv6ExtHdr(payload []byte, off, hdrEnd int) (int, uint8, bool) {
if off+8 > hdrEnd {
return 0, 0, false
}
extLen := (int(payload[off+1]) + 1) * 8
if off+extLen > hdrEnd {
return 0, 0, false
}
return off + extLen, payload[off], true
}
func (f *Forwarder) handleICMPDirect(payload []byte) bool {
if len(payload) == 0 {
return false
}
var (
ipHdrLen int
icmpLen int
srcAddr tcpip.Address
dstAddr tcpip.Address
ok bool
)
version := payload[0] >> 4
switch version {
case 4:
ipHdrLen, icmpLen, srcAddr, dstAddr, ok = parseICMPv4(payload)
case 6:
ipHdrLen, icmpLen, srcAddr, dstAddr, ok = parseICMPv6(payload)
}
if !ok {
return false
}
// Let gVisor handle ICMP destined for our own addresses natively.
// Its network-layer auto-reply is correct and efficient for local traffic.
if f.ip.Equal(dstAddr) || f.ipv6.Equal(dstAddr) {
return false
}
id := stack.TransportEndpointID{
LocalAddress: dstAddr,
RemoteAddress: srcAddr,
}
// Trim the buffer to the IP-declared length so gVisor doesn't see padding.
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(payload[:ipHdrLen+icmpLen]),
})
defer pkt.DecRef()
if _, ok := pkt.NetworkHeader().Consume(ipHdrLen); !ok {
return false
}
if _, ok := pkt.TransportHeader().Consume(icmpLen); !ok {
return false
}
if version == 6 {
return f.handleICMPv6(id, pkt)
}
return f.handleICMP(id, pkt)
}
// Stop gracefully shuts down the forwarder
func (f *Forwarder) Stop() {
f.cancel()
@@ -378,14 +177,11 @@ func (f *Forwarder) Stop() {
f.stack.Wait()
}
func (f *Forwarder) determineDialAddr(addr tcpip.Address) netip.Addr {
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
if f.netstack && f.ip.Equal(addr) {
return netip.AddrFrom4([4]byte{127, 0, 0, 1})
return net.IPv4(127, 0, 0, 1)
}
if f.netstack && f.ipv6.Equal(addr) {
return netip.IPv6Loopback()
}
return addrToNetipAddr(addr)
return addr.AsSlice()
}
func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) {
@@ -419,50 +215,23 @@ func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKe
}
}
// addrFromNetipAddr converts a netip.Addr to a gvisor tcpip.Address without allocating.
func addrFromNetipAddr(addr netip.Addr) tcpip.Address {
if !addr.IsValid() {
return tcpip.Address{}
}
if addr.Is4() {
return tcpip.AddrFrom4(addr.As4())
}
return tcpip.AddrFrom16(addr.As16())
}
// addrToNetipAddr converts a gvisor tcpip.Address to netip.Addr without allocating.
func addrToNetipAddr(addr tcpip.Address) netip.Addr {
switch addr.Len() {
case 4:
return netip.AddrFrom4(addr.As4())
case 16:
return netip.AddrFrom16(addr.As16())
default:
return netip.Addr{}
}
}
// checkICMPCapability tests whether we have raw ICMP socket access at startup.
func (f *Forwarder) checkICMPCapability() {
f.hasRawICMPAccess = probeRawICMP("ip4:icmp", "0.0.0.0", f.logger)
f.hasRawICMPv6Access = probeRawICMP("ip6:ipv6-icmp", "::", f.logger)
}
func probeRawICMP(network, addr string, logger *nblog.Logger) bool {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
lc := net.ListenConfig{}
conn, err := lc.ListenPacket(ctx, network, addr)
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil {
logger.Debug1("forwarder: no raw %s socket access, will use ping binary fallback", network)
return false
f.hasRawICMPAccess = false
f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback")
return
}
if err := conn.Close(); err != nil {
logger.Debug2("forwarder: failed to close %s capability test socket: %v", network, err)
f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err)
}
logger.Debug1("forwarder: raw %s socket access available", network)
return true
f.hasRawICMPAccess = true
f.logger.Debug("forwarder: Raw ICMP socket access available")
}

View File

@@ -1,162 +0,0 @@
package forwarder
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
const echoRequestSize = 8
func makeIPv6(t *testing.T, src, dst netip.Addr, nextHdr uint8, payload []byte) []byte {
t.Helper()
buf := make([]byte, header.IPv6MinimumSize+len(payload))
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(len(payload)),
TransportProtocol: 0, // overwritten below to allow any value
HopLimit: 64,
SrcAddr: tcpipAddrFromNetip(src),
DstAddr: tcpipAddrFromNetip(dst),
})
buf[6] = nextHdr
copy(buf[header.IPv6MinimumSize:], payload)
return buf
}
func tcpipAddrFromNetip(a netip.Addr) tcpip.Address {
b := a.As16()
return tcpip.AddrFrom16(b)
}
func echoRequest() []byte {
icmp := make([]byte, echoRequestSize)
icmp[0] = uint8(header.ICMPv6EchoRequest)
return icmp
}
// extHdr builds a generic IPv6 extension header (HBH/Routing/DestOpts) of the
// given total octet length (must be multiple of 8, >= 8) with the given next
// header.
func extHdr(t *testing.T, next uint8, totalLen int) []byte {
t.Helper()
require.GreaterOrEqual(t, totalLen, 8)
require.Equal(t, 0, totalLen%8)
buf := make([]byte, totalLen)
buf[0] = next
buf[1] = uint8(totalLen/8 - 1)
return buf
}
func TestParseICMPv6_NoExtensions(t *testing.T) {
src := netip.MustParseAddr("fd00::1")
dst := netip.MustParseAddr("fd00::2")
pkt := makeIPv6(t, src, dst, uint8(header.ICMPv6ProtocolNumber), echoRequest())
off, icmpLen, _, _, ok := parseICMPv6(pkt)
require.True(t, ok)
assert.Equal(t, header.IPv6MinimumSize, off)
assert.Equal(t, echoRequestSize, icmpLen)
}
func TestParseICMPv6_SingleExtension(t *testing.T) {
src := netip.MustParseAddr("fd00::1")
dst := netip.MustParseAddr("fd00::2")
hbh := extHdr(t, uint8(header.ICMPv6ProtocolNumber), 8)
payload := append([]byte{}, hbh...)
payload = append(payload, echoRequest()...)
pkt := makeIPv6(t, src, dst, uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), payload)
off, icmpLen, _, _, ok := parseICMPv6(pkt)
require.True(t, ok)
assert.Equal(t, header.IPv6MinimumSize+8, off)
assert.Equal(t, echoRequestSize, icmpLen)
}
func TestParseICMPv6_ChainedExtensions(t *testing.T) {
src := netip.MustParseAddr("fd00::1")
dst := netip.MustParseAddr("fd00::2")
dest := extHdr(t, uint8(header.ICMPv6ProtocolNumber), 16)
rt := extHdr(t, uint8(header.IPv6DestinationOptionsExtHdrIdentifier), 8)
hbh := extHdr(t, uint8(header.IPv6RoutingExtHdrIdentifier), 8)
payload := append(append(append([]byte{}, hbh...), rt...), dest...)
payload = append(payload, echoRequest()...)
pkt := makeIPv6(t, src, dst, uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), payload)
off, icmpLen, _, _, ok := parseICMPv6(pkt)
require.True(t, ok)
assert.Equal(t, header.IPv6MinimumSize+8+8+16, off)
assert.Equal(t, echoRequestSize, icmpLen)
}
func TestParseICMPv6_FragmentDefersToGVisor(t *testing.T) {
src := netip.MustParseAddr("fd00::1")
dst := netip.MustParseAddr("fd00::2")
pkt := makeIPv6(t, src, dst, uint8(header.IPv6FragmentExtHdrIdentifier), make([]byte, 8))
_, _, _, _, ok := parseICMPv6(pkt)
assert.False(t, ok)
}
func TestParseICMPv6_TruncatedExtension(t *testing.T) {
src := netip.MustParseAddr("fd00::1")
dst := netip.MustParseAddr("fd00::2")
// Extension claims 16 bytes but only 8 remain after the IP header.
hbh := []byte{uint8(header.ICMPv6ProtocolNumber), 1, 0, 0, 0, 0, 0, 0}
pkt := makeIPv6(t, src, dst, uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), hbh)
_, _, _, _, ok := parseICMPv6(pkt)
assert.False(t, ok)
}
func TestParseICMPv6_TruncatedICMPPayload(t *testing.T) {
src := netip.MustParseAddr("fd00::1")
dst := netip.MustParseAddr("fd00::2")
// PayloadLength claims 8 bytes of ICMPv6 but the buffer only holds 4.
pkt := makeIPv6(t, src, dst, uint8(header.ICMPv6ProtocolNumber), make([]byte, 8))
pkt = pkt[:header.IPv6MinimumSize+4]
_, _, _, _, ok := parseICMPv6(pkt)
assert.False(t, ok)
}
func TestParseICMPv4_RejectsShortIHL(t *testing.T) {
pkt := make([]byte, 28)
pkt[0] = 0x44 // version 4, IHL 4 (16 bytes - below minimum)
pkt[9] = uint8(header.ICMPv4ProtocolNumber)
header.IPv4(pkt).SetTotalLength(28)
_, _, _, _, ok := parseICMPv4(pkt)
assert.False(t, ok)
}
func TestParseICMPv4_RejectsTotalLenOverBuffer(t *testing.T) {
pkt := make([]byte, header.IPv4MinimumSize+header.ICMPv4MinimumSize)
ip := header.IPv4(pkt)
ip.Encode(&header.IPv4Fields{
TotalLength: uint16(len(pkt) + 16),
Protocol: uint8(header.ICMPv4ProtocolNumber),
TTL: 64,
})
_, _, _, _, ok := parseICMPv4(pkt)
assert.False(t, ok)
}
func TestParseICMPv4_RejectsFragment(t *testing.T) {
pkt := make([]byte, header.IPv4MinimumSize+header.ICMPv4MinimumSize)
ip := header.IPv4(pkt)
ip.Encode(&header.IPv4Fields{
TotalLength: uint16(len(pkt)),
Protocol: uint8(header.ICMPv4ProtocolNumber),
TTL: 64,
Flags: header.IPv4FlagMoreFragments,
})
_, _, _, _, ok := parseICMPv4(pkt)
assert.False(t, ok)
}

View File

@@ -35,7 +35,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt *stack.PacketBu
}
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), false, 100*time.Millisecond)
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 100*time.Millisecond)
if err != nil {
f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err)
return true
@@ -58,7 +58,7 @@ func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointI
defer func() { <-f.pingSemaphore }()
if f.hasRawICMPAccess {
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes, false)
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
} else {
f.handleICMPViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
}
@@ -72,23 +72,18 @@ func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointI
// forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection.
// The caller is responsible for closing the returned connection.
func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, v6 bool, timeout time.Duration) (net.PacketConn, error) {
func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, timeout time.Duration) (net.PacketConn, error) {
ctx, cancel := context.WithTimeout(f.ctx, timeout)
defer cancel()
network, listenAddr := "ip4:icmp", "0.0.0.0"
if v6 {
network, listenAddr = "ip6:ipv6-icmp", "::"
}
lc := net.ListenConfig{}
conn, err := lc.ListenPacket(ctx, network, listenAddr)
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil {
return nil, fmt.Errorf("create ICMP socket: %w", err)
}
dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP.AsSlice()}
dst := &net.IPAddr{IP: dstIP}
if _, err = conn.WriteTo(payload, dst); err != nil {
if closeErr := conn.Close(); closeErr != nil {
@@ -103,11 +98,11 @@ func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []by
return conn, nil
}
// handleICMPViaSocket handles ICMP echo requests using raw sockets for both v4 and v6.
func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int, v6 bool) {
// handleICMPViaSocket handles ICMP echo requests using raw sockets.
func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
sendTime := time.Now()
conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, v6, 5*time.Second)
conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, 5*time.Second)
if err != nil {
f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err)
return
@@ -118,20 +113,16 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp
}
}()
txBytes := f.handleEchoResponse(conn, id, v6)
txBytes := f.handleEchoResponse(conn, id)
rtt := time.Since(sendTime).Round(10 * time.Microsecond)
proto := "ICMP"
if v6 {
proto = "ICMPv6"
}
f.logger.Trace5("forwarder: Forwarded %s echo reply %v type %v code %v (rtt=%v, raw socket)",
proto, epID(id), icmpType, icmpCode, rtt)
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)",
epID(id), icmpType, icmpCode, rtt)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID, v6 bool) int {
func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) int {
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err)
return 0
@@ -146,19 +137,6 @@ func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEn
return 0
}
if v6 {
// Recompute checksum: the raw socket response has a checksum computed
// over the real endpoint addresses, but we inject with overlay addresses.
icmpHdr := header.ICMPv6(response[:n])
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: id.LocalAddress,
Dst: id.RemoteAddress,
}))
return f.injectICMPv6Reply(id, response[:n])
}
return f.injectICMPReply(id, response[:n])
}
@@ -172,23 +150,19 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T
txPackets = 1
}
srcIp := addrToNetipAddr(id.RemoteAddress)
dstIp := addrToNetipAddr(id.LocalAddress)
proto := nftypes.ICMP
if srcIp.Is6() {
proto = nftypes.ICMPv6
}
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
fields := nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: proto,
SourceIP: srcIp,
DestIP: dstIp,
ICMPType: icmpType,
ICMPCode: icmpCode,
Protocol: nftypes.ICMP,
// TODO: handle ipv6
SourceIP: srcIp,
DestIP: dstIp,
ICMPType: icmpType,
ICMPCode: icmpCode,
RxBytes: rxBytes,
TxBytes: txBytes,
@@ -235,164 +209,26 @@ func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpoi
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
// handleICMPv6 handles ICMPv6 packets from the network stack.
func (f *Forwarder) handleICMPv6(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
icmpHdr := header.ICMPv6(pkt.TransportHeader().View().AsSlice())
flowID := uuid.New()
f.sendICMPEvent(nftypes.TypeStart, flowID, id, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 0, 0)
if icmpHdr.Type() == header.ICMPv6EchoRequest {
return f.handleICMPv6Echo(flowID, id, pkt, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()))
}
// For non-echo types (Destination Unreachable, Packet Too Big, etc), forward without waiting
if !f.hasRawICMPv6Access {
f.logger.Debug2("forwarder: Cannot handle ICMPv6 type %v without raw socket access for %v", icmpHdr.Type(), epID(id))
return false
}
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), true, 100*time.Millisecond)
if err != nil {
f.logger.Error2("forwarder: Failed to forward ICMPv6 packet for %v: %v", epID(id), err)
return true
}
if err := conn.Close(); err != nil {
f.logger.Debug1("forwarder: Failed to close ICMPv6 socket: %v", err)
}
return true
}
// handleICMPv6Echo handles ICMPv6 echo requests via raw socket or ping binary fallback.
func (f *Forwarder) handleICMPv6Echo(flowID uuid.UUID, id stack.TransportEndpointID, pkt *stack.PacketBuffer, icmpType, icmpCode uint8) bool {
select {
case f.pingSemaphore <- struct{}{}:
icmpData := stack.PayloadSince(pkt.TransportHeader()).ToSlice()
rxBytes := pkt.Size()
go func() {
defer func() { <-f.pingSemaphore }()
if f.hasRawICMPv6Access {
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes, true)
} else {
f.handleICMPv6ViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
}
}()
default:
f.logger.Debug3("forwarder: ICMPv6 rate limit exceeded for %v type %v code %v", epID(id), icmpType, icmpCode)
}
return true
}
// handleICMPv6ViaPing uses the system ping6 binary for ICMPv6 echo.
func (f *Forwarder) handleICMPv6ViaPing(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
defer cancel()
dstIP := f.determineDialAddr(id.LocalAddress)
cmd := buildPingCommand(ctx, dstIP, 5*time.Second)
pingStart := time.Now()
if err := cmd.Run(); err != nil {
f.logger.Warn4("forwarder: Ping6 failed for %v type %v code %v: %v", epID(id), icmpType, icmpCode, err)
return
}
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
f.logger.Trace3("forwarder: Forwarded ICMPv6 echo request %v type %v code %v",
epID(id), icmpType, icmpCode)
txBytes := f.synthesizeICMPv6EchoReply(id, icmpData)
f.logger.Trace4("forwarder: Forwarded ICMPv6 echo reply %v type %v code %v (rtt=%v, ping binary)",
epID(id), icmpType, icmpCode, rtt)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
// synthesizeICMPv6EchoReply creates an ICMPv6 echo reply and injects it back.
func (f *Forwarder) synthesizeICMPv6EchoReply(id stack.TransportEndpointID, icmpData []byte) int {
replyICMP := make([]byte, len(icmpData))
copy(replyICMP, icmpData)
replyHdr := header.ICMPv6(replyICMP)
replyHdr.SetType(header.ICMPv6EchoReply)
replyHdr.SetChecksum(0)
// ICMPv6Checksum computes the pseudo-header internally from Src/Dst.
// Header contains the full ICMP message, so PayloadCsum/PayloadLen are zero.
replyHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: replyHdr,
Src: id.LocalAddress,
Dst: id.RemoteAddress,
}))
return f.injectICMPv6Reply(id, replyICMP)
}
// injectICMPv6Reply wraps an ICMPv6 payload in an IPv6 header and sends to the peer.
func (f *Forwarder) injectICMPv6Reply(id stack.TransportEndpointID, icmpPayload []byte) int {
ipHdr := make([]byte, header.IPv6MinimumSize)
ip := header.IPv6(ipHdr)
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(len(icmpPayload)),
TransportProtocol: header.ICMPv6ProtocolNumber,
HopLimit: 64,
SrcAddr: id.LocalAddress,
DstAddr: id.RemoteAddress,
})
fullPacket := make([]byte, 0, len(ipHdr)+len(icmpPayload))
fullPacket = append(fullPacket, ipHdr...)
fullPacket = append(fullPacket, icmpPayload...)
if err := f.endpoint.device.CreateOutboundPacket(fullPacket, id.RemoteAddress.AsSlice()); err != nil {
f.logger.Error1("forwarder: Failed to send ICMPv6 reply to peer: %v", err)
return 0
}
return len(fullPacket)
}
const (
pingBin = "ping"
ping6Bin = "ping6"
)
// buildPingCommand creates a platform-specific ping command.
// Most platforms auto-detect IPv6 from raw addresses. macOS/iOS/OpenBSD require ping6.
func buildPingCommand(ctx context.Context, target netip.Addr, timeout time.Duration) *exec.Cmd {
func buildPingCommand(ctx context.Context, target net.IP, timeout time.Duration) *exec.Cmd {
timeoutSec := int(timeout.Seconds())
if timeoutSec < 1 {
timeoutSec = 1
}
isV6 := target.Is6()
timeoutStr := fmt.Sprintf("%d", timeoutSec)
switch runtime.GOOS {
case "linux", "android":
return exec.CommandContext(ctx, pingBin, "-c", "1", "-W", timeoutStr, "-q", target.String())
return exec.CommandContext(ctx, "ping", "-c", "1", "-W", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
case "darwin", "ios":
bin := pingBin
if isV6 {
bin = ping6Bin
}
return exec.CommandContext(ctx, bin, "-c", "1", "-t", timeoutStr, "-q", target.String())
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
case "freebsd":
return exec.CommandContext(ctx, pingBin, "-c", "1", "-t", timeoutStr, target.String())
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), target.String())
case "openbsd", "netbsd":
bin := pingBin
if isV6 {
bin = ping6Bin
}
return exec.CommandContext(ctx, bin, "-c", "1", "-w", timeoutStr, target.String())
return exec.CommandContext(ctx, "ping", "-c", "1", "-w", fmt.Sprintf("%d", timeoutSec), target.String())
case "windows":
return exec.CommandContext(ctx, pingBin, "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
return exec.CommandContext(ctx, "ping", "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
default:
return exec.CommandContext(ctx, pingBin, "-c", "1", target.String())
return exec.CommandContext(ctx, "ping", "-c", "1", target.String())
}
}

View File

@@ -2,9 +2,10 @@ package forwarder
import (
"context"
"fmt"
"io"
"net"
"strconv"
"net/netip"
"sync"
"github.com/google/uuid"
@@ -32,7 +33,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
}
}()
dialAddr := net.JoinHostPort(f.determineDialAddr(id.LocalAddress).String(), strconv.Itoa(int(id.LocalPort)))
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if err != nil {
@@ -132,14 +133,15 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
}
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
srcIp := addrToNetipAddr(id.RemoteAddress)
dstIp := addrToNetipAddr(id.LocalAddress)
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
fields := nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.TCP,
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.TCP,
// TODO: handle ipv6
SourceIP: srcIp,
DestIP: dstIp,
SourcePort: id.RemotePort,

View File

@@ -6,7 +6,7 @@ import (
"fmt"
"io"
"net"
"strconv"
"net/netip"
"sync"
"sync/atomic"
"time"
@@ -158,7 +158,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
}
}()
dstAddr := net.JoinHostPort(f.determineDialAddr(id.LocalAddress).String(), strconv.Itoa(int(id.LocalPort)))
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil {
f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err)
@@ -276,14 +276,15 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
// sendUDPEvent stores flow events for UDP connections
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
srcIp := addrToNetipAddr(id.RemoteAddress)
dstIp := addrToNetipAddr(id.LocalAddress)
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
fields := nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.UDP,
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.UDP,
// TODO: handle ipv6
SourceIP: srcIp,
DestIP: dstIp,
SourcePort: id.RemotePort,

View File

@@ -13,6 +13,7 @@ const (
ipv4HeaderMinLen = 20
ipv4ProtoOffset = 9
ipv4FlagsOffset = 6
ipv4DstOffset = 16
ipProtoUDP = 17
ipProtoTCP = 6
ipv4FragOffMask = 0x1fff

View File

@@ -4,32 +4,89 @@ import (
"fmt"
"net"
"net/netip"
"sync/atomic"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
)
// localIPSnapshot is an immutable snapshot of local IP addresses, swapped
// atomically so reads are lock-free.
type localIPSnapshot struct {
ips map[netip.Addr]struct{}
type localIPManager struct {
mu sync.RWMutex
// fixed-size high array for upper byte of a IPv4 address
ipv4Bitmap [256]*ipv4LowBitmap
}
type localIPManager struct {
snapshot atomic.Pointer[localIPSnapshot]
// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address
type ipv4LowBitmap struct {
bitmap [8192]uint32
}
func newLocalIPManager() *localIPManager {
m := &localIPManager{}
m.snapshot.Store(&localIPSnapshot{
ips: make(map[netip.Addr]struct{}),
})
return m
return &localIPManager{}
}
func processInterface(iface net.Interface, ips map[netip.Addr]struct{}, addresses *[]netip.Addr) {
func (m *localIPManager) setBitmapBit(ip net.IP) {
ipv4 := ip.To4()
if ipv4 == nil {
return
}
high := uint16(ipv4[0])
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
index := low / 32
bit := low % 32
if m.ipv4Bitmap[high] == nil {
m.ipv4Bitmap[high] = &ipv4LowBitmap{}
}
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
}
func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
if !ip.Is4() {
return
}
ipv4 := ip.AsSlice()
high := uint16(ipv4[0])
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
if bitmap[high] == nil {
bitmap[high] = &ipv4LowBitmap{}
}
index := low / 32
bit := low % 32
bitmap[high].bitmap[index] |= 1 << bit
if _, exists := ipv4Set[ip]; !exists {
ipv4Set[ip] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ip)
}
}
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
high := uint16(ip[0])
low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3])
if m.ipv4Bitmap[high] == nil {
return false
}
index := low / 32
bit := low % 32
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
}
func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
return nil
}
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
addrs, err := iface.Addrs()
if err != nil {
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
@@ -47,19 +104,18 @@ func processInterface(iface net.Interface, ips map[netip.Addr]struct{}, addresse
continue
}
parsed, ok := netip.AddrFromSlice(ip)
addr, ok := netip.AddrFromSlice(ip)
if !ok {
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
continue
}
parsed = parsed.Unmap()
ips[parsed] = struct{}{}
*addresses = append(*addresses, parsed)
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
log.Debugf("process IP failed: %v", err)
}
}
}
// UpdateLocalIPs rebuilds the local IP snapshot and swaps it in atomically.
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
defer func() {
if r := recover(); r != nil {
@@ -67,20 +123,20 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
}
}()
ips := make(map[netip.Addr]struct{})
var addresses []netip.Addr
var newIPv4Bitmap [256]*ipv4LowBitmap
ipv4Set := make(map[netip.Addr]struct{})
var ipv4Addresses []netip.Addr
// loopback
ips[netip.AddrFrom4([4]byte{127, 0, 0, 1})] = struct{}{}
ips[netip.IPv6Loopback()] = struct{}{}
// 127.0.0.0/8
newIPv4Bitmap[127] = &ipv4LowBitmap{}
for i := 0; i < 8192; i++ {
// #nosec G602 -- bitmap is defined as [8192]uint32, loop range is correct
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
}
if iface != nil {
ip := iface.Address().IP
ips[ip] = struct{}{}
addresses = append(addresses, ip)
if v6 := iface.Address().IPv6; v6.IsValid() {
ips[v6] = struct{}{}
addresses = append(addresses, v6)
if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil {
return err
}
}
@@ -91,24 +147,25 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
// TODO: filter out down interfaces (net.FlagUp). Also handle the reverse
// case where an interface comes up between refreshes.
for _, intf := range interfaces {
processInterface(intf, ips, &addresses)
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
}
}
m.snapshot.Store(&localIPSnapshot{ips: ips})
m.mu.Lock()
m.ipv4Bitmap = newIPv4Bitmap
m.mu.Unlock()
log.Debugf("Local IP addresses: %v", addresses)
log.Debugf("Local IPv4 addresses: %v", ipv4Addresses)
return nil
}
// IsLocalIP checks if the given IP is a local address. Lock-free on the read path.
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
s := m.snapshot.Load()
if ip.Is4() && ip.As4()[0] == 127 {
return true
if !ip.Is4() {
return false
}
_, found := s.ips[ip]
return found
m.mu.RLock()
defer m.mu.RUnlock()
return m.checkBitmapBit(ip.AsSlice())
}

View File

@@ -1,72 +0,0 @@
package uspfilter
import (
"net/netip"
"testing"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func setupManager(b *testing.B) *localIPManager {
b.Helper()
m := newLocalIPManager()
mock := &IFaceMock{
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/16"),
IPv6: netip.MustParseAddr("fd00::1"),
IPv6Net: netip.MustParsePrefix("fd00::/64"),
}
},
}
if err := m.UpdateLocalIPs(mock); err != nil {
b.Fatalf("UpdateLocalIPs: %v", err)
}
return m
}
func BenchmarkIsLocalIP_v4_hit(b *testing.B) {
m := setupManager(b)
ip := netip.MustParseAddr("100.64.0.1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
m.IsLocalIP(ip)
}
}
func BenchmarkIsLocalIP_v4_miss(b *testing.B) {
m := setupManager(b)
ip := netip.MustParseAddr("8.8.8.8")
b.ResetTimer()
for i := 0; i < b.N; i++ {
m.IsLocalIP(ip)
}
}
func BenchmarkIsLocalIP_v6_hit(b *testing.B) {
m := setupManager(b)
ip := netip.MustParseAddr("fd00::1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
m.IsLocalIP(ip)
}
}
func BenchmarkIsLocalIP_v6_miss(b *testing.B) {
m := setupManager(b)
ip := netip.MustParseAddr("2001:db8::1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
m.IsLocalIP(ip)
}
}
func BenchmarkIsLocalIP_loopback(b *testing.B) {
m := setupManager(b)
ip := netip.MustParseAddr("127.0.0.1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
m.IsLocalIP(ip)
}
}

View File

@@ -72,45 +72,14 @@ func TestLocalIPManager(t *testing.T) {
expected: false,
},
{
name: "IPv6 address matches",
name: "IPv6 address",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/16"),
IPv6: netip.MustParseAddr("fd00::1"),
IPv6Net: netip.MustParsePrefix("fd00::/64"),
},
testIP: netip.MustParseAddr("fd00::1"),
expected: true,
},
{
name: "IPv6 address does not match",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/16"),
IPv6: netip.MustParseAddr("fd00::1"),
IPv6Net: netip.MustParsePrefix("fd00::/64"),
},
testIP: netip.MustParseAddr("fd00::99"),
expected: false,
},
{
name: "No aliasing between similar IPs",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"),
IP: netip.MustParseAddr("fe80::1"),
Network: netip.MustParsePrefix("192.168.1.0/24"),
},
testIP: netip.MustParseAddr("192.168.0.17"),
testIP: netip.MustParseAddr("fe80::1"),
expected: false,
},
{
name: "IPv6 loopback",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/16"),
},
testIP: netip.MustParseAddr("::1"),
expected: true,
},
}
for _, tt := range tests {
@@ -202,3 +171,90 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
})
}
}
// MapImplementation is a version using map[string]struct{}
type MapImplementation struct {
localIPs map[string]struct{}
}
func BenchmarkIPChecks(b *testing.B) {
interfaces := make([]net.IP, 16)
for i := range interfaces {
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
}
// Setup bitmap
bitmapManager := newLocalIPManager()
for _, ip := range interfaces[:8] { // Add half of IPs
bitmapManager.setBitmapBit(ip)
}
// Setup map version
mapManager := &MapImplementation{
localIPs: make(map[string]struct{}),
}
for _, ip := range interfaces[:8] {
mapManager.localIPs[ip.String()] = struct{}{}
}
b.Run("Bitmap_Hit", func(b *testing.B) {
ip := interfaces[4]
b.ResetTimer()
for i := 0; i < b.N; i++ {
bitmapManager.checkBitmapBit(ip)
}
})
b.Run("Bitmap_Miss", func(b *testing.B) {
ip := interfaces[12]
b.ResetTimer()
for i := 0; i < b.N; i++ {
bitmapManager.checkBitmapBit(ip)
}
})
b.Run("Map_Hit", func(b *testing.B) {
ip := interfaces[4]
b.ResetTimer()
for i := 0; i < b.N; i++ {
// nolint:gosimple
_ = mapManager.localIPs[ip.String()]
}
})
b.Run("Map_Miss", func(b *testing.B) {
ip := interfaces[12]
b.ResetTimer()
for i := 0; i < b.N; i++ {
// nolint:gosimple
_ = mapManager.localIPs[ip.String()]
}
})
}
func BenchmarkWGPosition(b *testing.B) {
wgIP := net.ParseIP("10.10.0.1")
// Create two managers - one checks WG IP first, other checks it last
b.Run("WG_First", func(b *testing.B) {
bm := newLocalIPManager()
bm.setBitmapBit(wgIP)
b.ResetTimer()
for i := 0; i < b.N; i++ {
bm.checkBitmapBit(wgIP)
}
})
b.Run("WG_Last", func(b *testing.B) {
bm := newLocalIPManager()
// Fill with other IPs first
for i := 0; i < 15; i++ {
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
}
bm.setBitmapBit(wgIP) // Add WG IP last
b.ResetTimer()
for i := 0; i < b.N; i++ {
bm.checkBitmapBit(wgIP)
}
})
}

View File

@@ -13,6 +13,8 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
var (
errInvalidIPHeaderLength = errors.New("invalid IP header length")
)
@@ -23,33 +25,10 @@ const (
destinationPortOffset = 2
// IP address offsets in IPv4 header
ipv4SrcOffset = 12
ipv4DstOffset = 16
// IP address offsets in IPv6 header
ipv6SrcOffset = 8
ipv6DstOffset = 24
// IPv6 fixed header length
ipv6HeaderLen = 40
sourceIPOffset = 12
destinationIPOffset = 16
)
// ipHeaderLen returns the IP header length based on the decoded layer type.
func ipHeaderLen(d *decoder) (int, error) {
switch d.decoded[0] {
case layers.LayerTypeIPv4:
n := int(d.ip4.IHL) * 4
if n < 20 {
return 0, errInvalidIPHeaderLength
}
return n, nil
case layers.LayerTypeIPv6:
return ipv6HeaderLen, nil
default:
return 0, fmt.Errorf("unknown IP layer: %v", d.decoded[0])
}
}
// ipv4Checksum calculates IPv4 header checksum.
func ipv4Checksum(header []byte) uint16 {
if len(header) < 20 {
@@ -255,13 +234,14 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
return false
}
_, dstIP := extractPacketIPs(packetData, d)
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
translatedIP, exists := m.getDNATTranslation(dstIP)
if !exists {
return false
}
if err := m.rewritePacketIP(packetData, d, translatedIP, false); err != nil {
if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil {
m.logger.Error1("failed to rewrite packet destination: %v", err)
return false
}
@@ -276,13 +256,14 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
return false
}
srcIP, _ := extractPacketIPs(packetData, d)
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
originalIP, exists := m.findReverseDNATMapping(srcIP)
if !exists {
return false
}
if err := m.rewritePacketIP(packetData, d, originalIP, true); err != nil {
if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil {
m.logger.Error1("failed to rewrite packet source: %v", err)
return false
}
@@ -291,96 +272,38 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
return true
}
// extractPacketIPs extracts src and dst IP addresses directly from raw packet bytes.
func extractPacketIPs(packetData []byte, d *decoder) (src, dst netip.Addr) {
switch d.decoded[0] {
case layers.LayerTypeIPv4:
src = netip.AddrFrom4([4]byte{packetData[ipv4SrcOffset], packetData[ipv4SrcOffset+1], packetData[ipv4SrcOffset+2], packetData[ipv4SrcOffset+3]})
dst = netip.AddrFrom4([4]byte{packetData[ipv4DstOffset], packetData[ipv4DstOffset+1], packetData[ipv4DstOffset+2], packetData[ipv4DstOffset+3]})
case layers.LayerTypeIPv6:
src = netip.AddrFrom16([16]byte(packetData[ipv6SrcOffset : ipv6SrcOffset+16]))
dst = netip.AddrFrom16([16]byte(packetData[ipv6DstOffset : ipv6DstOffset+16]))
}
return src, dst
}
// rewritePacketIP replaces a source (isSource=true) or destination IP address in the packet and updates checksums.
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, isSource bool) error {
hdrLen, err := ipHeaderLen(d)
if err != nil {
return err
}
switch d.decoded[0] {
case layers.LayerTypeIPv4:
return m.rewriteIPv4(packetData, d, newIP, hdrLen, isSource)
case layers.LayerTypeIPv6:
return m.rewriteIPv6(packetData, d, newIP, hdrLen, isSource)
default:
return fmt.Errorf("unknown IP layer: %v", d.decoded[0])
}
}
func (m *Manager) rewriteIPv4(packetData []byte, d *decoder, newIP netip.Addr, hdrLen int, isSource bool) error {
// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums.
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error {
if !newIP.Is4() {
return fmt.Errorf("cannot write IPv6 address into IPv4 packet")
}
offset := ipv4DstOffset
if isSource {
offset = ipv4SrcOffset
return ErrIPv4Only
}
var oldIP [4]byte
copy(oldIP[:], packetData[offset:offset+4])
copy(oldIP[:], packetData[ipOffset:ipOffset+4])
newIPBytes := newIP.As4()
copy(packetData[offset:offset+4], newIPBytes[:])
// Recalculate IPv4 header checksum
copy(packetData[ipOffset:ipOffset+4], newIPBytes[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
binary.BigEndian.PutUint16(packetData[10:12], ipv4Checksum(packetData[:hdrLen]))
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
// Update transport checksums incrementally
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, hdrLen)
m.updateICMPChecksum(packetData, ipHeaderLen)
}
}
return nil
}
func (m *Manager) rewriteIPv6(packetData []byte, d *decoder, newIP netip.Addr, hdrLen int, isSource bool) error {
if !newIP.Is6() {
return fmt.Errorf("cannot write IPv4 address into IPv6 packet")
}
offset := ipv6DstOffset
if isSource {
offset = ipv6SrcOffset
}
var oldIP [16]byte
copy(oldIP[:], packetData[offset:offset+16])
newIPBytes := newIP.As16()
copy(packetData[offset:offset+16], newIPBytes[:])
// IPv6 has no header checksum, only update transport checksums
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeICMPv6:
// ICMPv6 checksum includes pseudo-header with addresses, use incremental update
m.updateICMPv6Checksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
}
}
return nil
}
@@ -428,20 +351,6 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
}
// updateICMPv6Checksum updates ICMPv6 checksum after address change.
// ICMPv6 uses a pseudo-header (like TCP/UDP), so incremental update applies.
func (m *Manager) updateICMPv6Checksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
icmpStart := ipHeaderLen
if len(packetData) < icmpStart+4 {
return
}
checksumOffset := icmpStart + 2
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
// incrementalUpdate performs incremental checksum update per RFC 1624.
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
sum := uint32(^oldChecksum)
@@ -494,14 +403,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
}
// addPortRedirection adds a port redirection rule.
func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, originalPort, translatedPort uint16) error {
func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
m.portDNATMutex.Lock()
defer m.portDNATMutex.Unlock()
rule := portDNATRule{
protocol: protocol,
origPort: originalPort,
targetPort: translatedPort,
origPort: sourcePort,
targetPort: targetPort,
targetIP: targetIP,
}
@@ -513,7 +422,7 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
// TODO: also delegate to nativeFirewall when available for kernel WG mode
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
switch protocol {
case firewall.ProtocolTCP:
@@ -524,16 +433,16 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
return fmt.Errorf("unsupported protocol: %s", protocol)
}
return m.addPortRedirection(localAddr, layerType, originalPort, translatedPort)
return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// removePortRedirection removes a port redirection rule.
func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, originalPort, translatedPort uint16) error {
func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
m.portDNATMutex.Lock()
defer m.portDNATMutex.Unlock()
m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool {
return rule.protocol == protocol && rule.origPort == originalPort && rule.targetPort == translatedPort && rule.targetIP.Compare(targetIP) == 0
return rule.protocol == protocol && rule.origPort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0
})
if len(m.portDNATRules) == 0 {
@@ -544,7 +453,7 @@ func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.L
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
switch protocol {
case firewall.ProtocolTCP:
@@ -555,23 +464,23 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return fmt.Errorf("unsupported protocol: %s", protocol)
}
return m.removePortRedirection(localAddr, layerType, originalPort, translatedPort)
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// AddOutputDNAT delegates to the native firewall if available.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if m.nativeFirewall == nil {
return fmt.Errorf("output DNAT not supported without native firewall")
}
return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveOutputDNAT delegates to the native firewall if available.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if m.nativeFirewall == nil {
return nil
}
return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
@@ -623,12 +532,12 @@ func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP neti
// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum.
func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
hdrLen, err := ipHeaderLen(d)
if err != nil {
return err
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength
}
tcpStart := hdrLen
tcpStart := ipHeaderLen
if len(packetData) < tcpStart+4 {
return fmt.Errorf("packet too short for TCP header")
}
@@ -654,12 +563,12 @@ func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16,
// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum.
func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
hdrLen, err := ipHeaderLen(d)
if err != nil {
return err
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength
}
udpStart := hdrLen
udpStart := ipHeaderLen
if len(packetData) < udpStart+8 {
return fmt.Errorf("packet too short for UDP header")
}

View File

@@ -342,17 +342,12 @@ func BenchmarkDNATMemoryAllocations(b *testing.B) {
// Parse the packet fresh each time to get a clean decoder
d := &decoder{decoded: []gopacket.LayerType{}}
d.parser4 = gopacket.NewDecodingLayerParser(
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser4.IgnoreUnsupported = true
d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
err = d.decodePacket(testPacket)
d.parser.IgnoreUnsupported = true
err = d.parser.DecodeLayers(testPacket, &d.decoded)
assert.NoError(b, err)
manager.translateOutboundDNAT(testPacket, d)
@@ -376,17 +371,12 @@ func BenchmarkDirectIPExtraction(b *testing.B) {
b.Run("decoder_extraction", func(b *testing.B) {
// Create decoder once for comparison
d := &decoder{decoded: []gopacket.LayerType{}}
d.parser4 = gopacket.NewDecodingLayerParser(
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser4.IgnoreUnsupported = true
d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
err := d.decodePacket(packet)
d.parser.IgnoreUnsupported = true
err := d.parser.DecodeLayers(packet, &d.decoded)
assert.NoError(b, err)
for i := 0; i < b.N; i++ {

View File

@@ -86,18 +86,13 @@ func parsePacket(t testing.TB, packetData []byte) *decoder {
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser4 = gopacket.NewDecodingLayerParser(
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser4.IgnoreUnsupported = true
d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
d.parser.IgnoreUnsupported = true
err := d.decodePacket(packetData)
err := d.parser.DecodeLayers(packetData, &d.decoded)
require.NoError(t, err)
return d
}

View File

@@ -2,9 +2,7 @@ package uspfilter
import (
"fmt"
"net"
"net/netip"
"strconv"
"time"
"github.com/google/gopacket"
@@ -114,13 +112,10 @@ func (t *PacketTrace) AddResultWithForwarder(stage PacketStage, message string,
}
func (p *PacketBuilder) Build() ([]byte, error) {
ipLayer, err := p.buildIPLayer()
if err != nil {
return nil, err
}
pktLayers := []gopacket.SerializableLayer{ipLayer}
ip := p.buildIPLayer()
pktLayers := []gopacket.SerializableLayer{ip}
transportLayer, err := p.buildTransportLayer(ipLayer)
transportLayer, err := p.buildTransportLayer(ip)
if err != nil {
return nil, err
}
@@ -134,43 +129,30 @@ func (p *PacketBuilder) Build() ([]byte, error) {
return serializePacket(pktLayers)
}
func (p *PacketBuilder) buildIPLayer() (gopacket.SerializableLayer, error) {
if p.SrcIP.Is4() != p.DstIP.Is4() {
return nil, fmt.Errorf("mixed address families: src=%s dst=%s", p.SrcIP, p.DstIP)
}
proto := getIPProtocolNumber(p.Protocol, p.SrcIP.Is6())
if p.SrcIP.Is6() {
return &layers.IPv6{
Version: 6,
HopLimit: 64,
NextHeader: proto,
SrcIP: p.SrcIP.AsSlice(),
DstIP: p.DstIP.AsSlice(),
}, nil
}
func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
return &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: proto,
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
SrcIP: p.SrcIP.AsSlice(),
DstIP: p.DstIP.AsSlice(),
}, nil
}
}
func (p *PacketBuilder) buildTransportLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
func (p *PacketBuilder) buildTransportLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
switch p.Protocol {
case "tcp":
return p.buildTCPLayer(ipLayer)
return p.buildTCPLayer(ip)
case "udp":
return p.buildUDPLayer(ipLayer)
return p.buildUDPLayer(ip)
case "icmp":
return p.buildICMPLayer(ipLayer)
return p.buildICMPLayer()
default:
return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol)
}
}
func (p *PacketBuilder) buildTCPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
tcp := &layers.TCP{
SrcPort: layers.TCPPort(p.SrcPort),
DstPort: layers.TCPPort(p.DstPort),
@@ -182,44 +164,24 @@ func (p *PacketBuilder) buildTCPLayer(ipLayer gopacket.SerializableLayer) ([]gop
PSH: p.TCPState != nil && p.TCPState.PSH,
URG: p.TCPState != nil && p.TCPState.URG,
}
if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
if err := tcp.SetNetworkLayerForChecksum(nl); err != nil {
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
}
if err := tcp.SetNetworkLayerForChecksum(ip); err != nil {
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
}
return []gopacket.SerializableLayer{tcp}, nil
}
func (p *PacketBuilder) buildUDPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
func (p *PacketBuilder) buildUDPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
udp := &layers.UDP{
SrcPort: layers.UDPPort(p.SrcPort),
DstPort: layers.UDPPort(p.DstPort),
}
if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
if err := udp.SetNetworkLayerForChecksum(nl); err != nil {
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
}
if err := udp.SetNetworkLayerForChecksum(ip); err != nil {
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
}
return []gopacket.SerializableLayer{udp}, nil
}
func (p *PacketBuilder) buildICMPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
if p.SrcIP.Is6() || p.DstIP.Is6() {
icmp := &layers.ICMPv6{
TypeCode: layers.CreateICMPv6TypeCode(p.ICMPType, p.ICMPCode),
}
if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
_ = icmp.SetNetworkLayerForChecksum(nl)
}
if p.ICMPType == layers.ICMPv6TypeEchoRequest || p.ICMPType == layers.ICMPv6TypeEchoReply {
echo := &layers.ICMPv6Echo{
Identifier: 1,
SeqNumber: 1,
}
return []gopacket.SerializableLayer{icmp, echo}, nil
}
return []gopacket.SerializableLayer{icmp}, nil
}
func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) {
icmp := &layers.ICMPv4{
TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode),
}
@@ -242,17 +204,14 @@ func serializePacket(layers []gopacket.SerializableLayer) ([]byte, error) {
return buf.Bytes(), nil
}
func getIPProtocolNumber(protocol fw.Protocol, isV6 bool) layers.IPProtocol {
func getIPProtocolNumber(protocol fw.Protocol) int {
switch protocol {
case fw.ProtocolTCP:
return layers.IPProtocolTCP
return int(layers.IPProtocolTCP)
case fw.ProtocolUDP:
return layers.IPProtocolUDP
return int(layers.IPProtocolUDP)
case fw.ProtocolICMP:
if isV6 {
return layers.IPProtocolICMPv6
}
return layers.IPProtocolICMPv4
return int(layers.IPProtocolICMPv4)
default:
return 0
}
@@ -275,7 +234,7 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
trace := &PacketTrace{Direction: direction}
// Initial packet decoding
if err := d.decodePacket(packetData); err != nil {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false)
return trace
}
@@ -297,8 +256,6 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
trace.DestinationPort = uint16(d.udp.DstPort)
case layers.LayerTypeICMPv4:
trace.Protocol = "ICMP"
case layers.LayerTypeICMPv6:
trace.Protocol = "ICMPv6"
}
trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d",
@@ -362,13 +319,6 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
flags&conntrack.TCPFin != 0)
case layers.LayerTypeICMPv4:
msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, d.icmp4.Seq)
case layers.LayerTypeICMPv6:
var id, seq uint16
if len(d.icmp6.Payload) >= 4 {
id = uint16(d.icmp6.Payload[0])<<8 | uint16(d.icmp6.Payload[1])
seq = uint16(d.icmp6.Payload[2])<<8 | uint16(d.icmp6.Payload[3])
}
msg += fmt.Sprintf(" (ICMPv6 ID=%d, Seq=%d)", id, seq)
}
return msg
}
@@ -445,7 +395,7 @@ func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP n
trace.AddResult(StageRouteACL, msg, allowed)
if allowed && m.forwarder.Load() != nil {
m.addForwardingResult(trace, "proxy-remote", net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort))), true)
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
}
trace.AddResult(StageCompleted, msgProcessingCompleted, allowed)
@@ -465,7 +415,7 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
if err := d.decodePacket(packetData); err != nil {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageCompleted, "Packet dropped - decode error", false)
return trace
}
@@ -484,7 +434,7 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool {
portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d)
if portDNATApplied {
if err := d.decodePacket(packetData); err != nil {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false)
return true
}
@@ -494,7 +444,7 @@ func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *de
nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d)
if nat1to1Applied {
if err := d.decodePacket(packetData); err != nil {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false)
return true
}
@@ -559,7 +509,7 @@ func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *
return false
}
srcIP, _ := extractPacketIPs(packetData, d)
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
translated := m.translateInboundReverse(packetData, d)
if translated {
@@ -589,7 +539,7 @@ func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d
return false
}
_, dstIP := extractPacketIPs(packetData, d)
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
translated := m.translateOutboundDNAT(packetData, d)
if translated {

View File

@@ -119,7 +119,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
if err != nil {
return fmt.Errorf("failed to parse endpoint address: %w", err)
}
addrPort := netip.AddrPortFrom(addr.Unmap(), uint16(endpoint.Port))
addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port))
c.activityRecorder.UpsertAddress(peerKey, addrPort)
}
return nil

View File

@@ -2,7 +2,7 @@ package device
// TunAdapter is an interface for create tun device from external service
type TunAdapter interface {
ConfigureInterface(address string, addressV6 string, mtu int, dns string, searchDomains string, routes string) (int, error)
ConfigureInterface(address string, mtu int, dns string, searchDomains string, routes string) (int, error)
UpdateAddr(address string) error
ProtectSocket(fd int32) bool
}

View File

@@ -63,7 +63,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
searchDomainsToString = ""
}
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.address.IPv6String(), int(t.mtu), dns, searchDomainsToString, routesString)
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), int(t.mtu), dns, searchDomainsToString, routesString)
if err != nil {
log.Errorf("failed to create Android interface: %s", err)
return nil, err

View File

@@ -131,32 +131,23 @@ func (t *TunDevice) Device() *device.Device {
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
func (t *TunDevice) assignAddr() error {
if out, err := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String()).CombinedOutput(); err != nil {
return fmt.Errorf("add v4 address: %s: %w", string(out), err)
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
if out, err := cmd.CombinedOutput(); err != nil {
log.Errorf("adding address command '%v' failed with output: %s", cmd.String(), out)
return err
}
// Assign a dummy link-local so macOS enables IPv6 on the tun device.
// When a real overlay v6 is present, use that instead.
v6Addr := "fe80::/64"
if t.address.HasIPv6() {
v6Addr = t.address.IPv6String()
}
if out, err := exec.Command("ifconfig", t.name, "inet6", v6Addr).CombinedOutput(); err != nil {
log.Warnf("failed to assign IPv6 address %s, continuing v4-only: %s: %v", v6Addr, string(out), err)
t.address.ClearIPv6()
// dummy ipv6 so routing works
cmd = exec.Command("ifconfig", t.name, "inet6", "fe80::/64")
if out, err := cmd.CombinedOutput(); err != nil {
log.Debugf("adding address command '%v' failed with output: %s", cmd.String(), out)
}
if out, err := exec.Command("route", "add", "-net", t.address.Network.String(), "-interface", t.name).CombinedOutput(); err != nil {
return fmt.Errorf("add route %s via %s: %s: %w", t.address.Network, t.name, string(out), err)
routeCmd := exec.Command("route", "add", "-net", t.address.Network.String(), "-interface", t.name)
if out, err := routeCmd.CombinedOutput(); err != nil {
log.Errorf("adding route command '%v' failed with output: %s", routeCmd.String(), out)
return err
}
if t.address.HasIPv6() {
if out, err := exec.Command("route", "add", "-inet6", "-net", t.address.IPv6Net.String(), "-interface", t.name).CombinedOutput(); err != nil {
log.Warnf("failed to add route %s via %s, continuing v4-only: %s: %v", t.address.IPv6Net, t.name, string(out), err)
t.address.ClearIPv6()
}
}
return nil
}

View File

@@ -151,11 +151,8 @@ func (t *TunDevice) MTU() uint16 {
return t.mtu
}
// UpdateAddr updates the device address. On iOS the tunnel is managed by the
// NetworkExtension, so we only store the new value. The extension picks up the
// change on the next tunnel reconfiguration.
func (t *TunDevice) UpdateAddr(addr wgaddr.Address) error {
t.address = addr
func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error {
// todo implement
return nil
}

View File

@@ -173,7 +173,7 @@ func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
// assignAddr Adds IP address to the tunnel interface
func (t *TunKernelDevice) assignAddr() error {
return t.link.assignAddr(&t.address)
return t.link.assignAddr(t.address)
}
func (t *TunKernelDevice) GetNet() *netstack.Net {

View File

@@ -3,7 +3,6 @@ package device
import (
"errors"
"fmt"
"net/netip"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn"
@@ -64,12 +63,8 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
return nil, fmt.Errorf("last ip: %w", err)
}
addresses := []netip.Addr{t.address.IP}
if t.address.HasIPv6() {
addresses = append(addresses, t.address.IPv6)
}
log.Debugf("netstack using addresses: %v", addresses)
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, addresses, dnsAddr, int(t.mtu))
log.Debugf("netstack using address: %s", t.address.IP)
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, int(t.mtu))
log.Debugf("netstack using dns address: %s", dnsAddr)
tunIface, net, err := t.nsTun.Create()
if err != nil {

View File

@@ -16,7 +16,7 @@ import (
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type TunDevice struct {
type USPDevice struct {
name string
address wgaddr.Address
port int
@@ -30,10 +30,10 @@ type TunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *TunDevice {
func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *USPDevice {
log.Infof("using userspace bind mode")
return &TunDevice{
return &USPDevice{
name: name,
address: address,
port: port,
@@ -43,7 +43,7 @@ func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu
}
}
func (t *TunDevice) Create() (WGConfigurer, error) {
func (t *USPDevice) Create() (WGConfigurer, error) {
log.Info("create tun interface")
tunIface, err := tun.CreateTUN(t.name, int(t.mtu))
if err != nil {
@@ -75,7 +75,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
if t.device == nil {
return nil, fmt.Errorf("device is not ready yet")
}
@@ -95,12 +95,12 @@ func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
func (t *USPDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address
return t.assignAddr()
}
func (t *TunDevice) Close() error {
func (t *USPDevice) Close() error {
if t.configurer != nil {
t.configurer.Close()
}
@@ -115,39 +115,39 @@ func (t *TunDevice) Close() error {
return nil
}
func (t *TunDevice) WgAddress() wgaddr.Address {
func (t *USPDevice) WgAddress() wgaddr.Address {
return t.address
}
func (t *TunDevice) MTU() uint16 {
func (t *USPDevice) MTU() uint16 {
return t.mtu
}
func (t *TunDevice) DeviceName() string {
func (t *USPDevice) DeviceName() string {
return t.name
}
func (t *TunDevice) FilteredDevice() *FilteredDevice {
func (t *USPDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}
// Device returns the wireguard device
func (t *TunDevice) Device() *device.Device {
func (t *USPDevice) Device() *device.Device {
return t.device
}
// assignAddr Adds IP address to the tunnel interface
func (t *TunDevice) assignAddr() error {
func (t *USPDevice) assignAddr() error {
link := newWGLink(t.name)
return link.assignAddr(&t.address)
return link.assignAddr(t.address)
}
func (t *TunDevice) GetNet() *netstack.Net {
func (t *USPDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the ICEBind instance
func (t *TunDevice) GetICEBind() EndpointManager {
func (t *USPDevice) GetICEBind() EndpointManager {
return t.iceBind
}

View File

@@ -87,21 +87,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
err = nbiface.Set()
if err != nil {
t.device.Close()
return nil, fmt.Errorf("set IPv4 interface MTU: %s", err)
}
if t.address.HasIPv6() {
nbiface6, err := luid.IPInterface(windows.AF_INET6)
if err != nil {
log.Warnf("failed to get IPv6 interface for MTU, continuing v4-only: %v", err)
t.address.ClearIPv6()
} else {
nbiface6.NLMTU = uint32(t.mtu)
if err := nbiface6.Set(); err != nil {
log.Warnf("failed to set IPv6 interface MTU, continuing v4-only: %v", err)
t.address.ClearIPv6()
}
}
return nil, fmt.Errorf("got error when getting setting the interface mtu: %s", err)
}
err = t.assignAddr()
if err != nil {
@@ -192,21 +178,8 @@ func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
func (t *TunDevice) assignAddr() error {
luid := winipcfg.LUID(t.nativeTunDevice.LUID())
v4Prefix := t.address.Prefix()
if t.address.HasIPv6() {
v6Prefix := t.address.IPv6Prefix()
log.Debugf("adding addresses %s, %s to interface: %s", v4Prefix, v6Prefix, t.name)
if err := luid.SetIPAddresses([]netip.Prefix{v4Prefix, v6Prefix}); err != nil {
log.Warnf("failed to assign dual-stack addresses, retrying v4-only: %v", err)
t.address.ClearIPv6()
return luid.SetIPAddresses([]netip.Prefix{v4Prefix})
}
return nil
}
log.Debugf("adding address %s to interface: %s", v4Prefix, t.name)
return luid.SetIPAddresses([]netip.Prefix{v4Prefix})
log.Debugf("adding address %s to interface: %s", t.address.IP, t.name)
return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())})
}
func (t *TunDevice) GetNet() *netstack.Net {

View File

@@ -0,0 +1,8 @@
//go:build (!linux && !freebsd) || android
package device
// WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
func WireGuardModuleIsLoaded() bool {
return false
}

View File

@@ -0,0 +1,18 @@
package device
// WireGuardModuleIsLoaded check if kernel support wireguard
func WireGuardModuleIsLoaded() bool {
// Despite the fact FreeBSD natively support Wireguard (https://github.com/WireGuard/wireguard-freebsd)
// we are currently do not use it, since it is required to add wireguard kernel support to
// - https://github.com/netbirdio/netbird/tree/main/sharedsock
// - https://github.com/mdlayher/socket
// TODO: implement kernel space
return false
}
// ModuleTunIsLoaded check if tun module exist, if is not attempt to load it
func ModuleTunIsLoaded() bool {
// Assume tun supported by freebsd kernel by default
// TODO: implement check for module loaded in kernel or build-it
return true
}

View File

@@ -1,13 +0,0 @@
//go:build !linux || android
package device
// WireGuardModuleIsLoaded reports whether the kernel WireGuard module is available.
func WireGuardModuleIsLoaded() bool {
return false
}
// ModuleTunIsLoaded reports whether the tun device is available.
func ModuleTunIsLoaded() bool {
return true
}

View File

@@ -2,7 +2,6 @@ package device
import (
"fmt"
"os/exec"
log "github.com/sirupsen/logrus"
@@ -58,32 +57,32 @@ func (l *wgLink) up() error {
return nil
}
func (l *wgLink) assignAddr(address *wgaddr.Address) error {
func (l *wgLink) assignAddr(address wgaddr.Address) error {
link, err := freebsd.LinkByName(l.name)
if err != nil {
return fmt.Errorf("link by name: %w", err)
}
ip := address.IP.String()
// Convert prefix length to hex netmask
prefixLen := address.Network.Bits()
if !address.IP.Is4() {
return fmt.Errorf("IPv6 not supported for interface assignment")
}
maskBits := uint32(0xffffffff) << (32 - prefixLen)
mask := fmt.Sprintf("0x%08x", maskBits)
log.Infof("assign addr %s mask %s to %s interface", address.IP, mask, l.name)
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)
if err := link.AssignAddr(address.IP.String(), mask); err != nil {
err = link.AssignAddr(ip, mask)
if err != nil {
return fmt.Errorf("assign addr: %w", err)
}
if address.HasIPv6() {
log.Infof("assign IPv6 addr %s to %s interface", address.IPv6String(), l.name)
cmd := exec.Command("ifconfig", l.name, "inet6", address.IPv6String())
if out, err := cmd.CombinedOutput(); err != nil {
log.Warnf("failed to assign IPv6 address %s to %s, continuing v4-only: %s: %v", address.IPv6String(), l.name, string(out), err)
address.ClearIPv6()
}
}
if err := link.Up(); err != nil {
err = link.Up()
if err != nil {
return fmt.Errorf("up: %w", err)
}

View File

@@ -4,8 +4,6 @@ package device
import (
"fmt"
"net"
"net/netip"
"os"
log "github.com/sirupsen/logrus"
@@ -94,7 +92,7 @@ func (l *wgLink) up() error {
return nil
}
func (l *wgLink) assignAddr(address *wgaddr.Address) error {
func (l *wgLink) assignAddr(address wgaddr.Address) error {
//delete existing addresses
list, err := netlink.AddrList(l, 0)
if err != nil {
@@ -112,16 +110,20 @@ func (l *wgLink) assignAddr(address *wgaddr.Address) error {
}
name := l.attrs.Name
addrStr := address.String()
if err := l.addAddr(name, address.Prefix()); err != nil {
return err
log.Debugf("adding address %s to interface: %s", addrStr, name)
addr, err := netlink.ParseAddr(addrStr)
if err != nil {
return fmt.Errorf("parse addr: %w", err)
}
if address.HasIPv6() {
if err := l.addAddr(name, address.IPv6Prefix()); err != nil {
log.Warnf("failed to assign IPv6 address %s to %s, continuing v4-only: %v", address.IPv6Prefix(), name, err)
address.ClearIPv6()
}
err = netlink.AddrAdd(l, addr)
if os.IsExist(err) {
log.Infof("interface %s already has the address: %s", name, addrStr)
} else if err != nil {
return fmt.Errorf("add addr: %w", err)
}
// On linux, the link must be brought up
@@ -131,22 +133,3 @@ func (l *wgLink) assignAddr(address *wgaddr.Address) error {
return nil
}
func (l *wgLink) addAddr(ifaceName string, prefix netip.Prefix) error {
log.Debugf("adding address %s to interface: %s", prefix, ifaceName)
addr := &netlink.Addr{
IPNet: &net.IPNet{
IP: prefix.Addr().AsSlice(),
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()),
},
}
if err := netlink.AddrAdd(l, addr); os.IsExist(err) {
log.Infof("interface %s already has the address: %s", ifaceName, prefix)
} else if err != nil {
return fmt.Errorf("add addr %s: %w", prefix, err)
}
return nil
}

View File

@@ -57,7 +57,7 @@ type wgProxyFactory interface {
type WGIFaceOpts struct {
IFaceName string
Address wgaddr.Address
Address string
WGPort int
WGPrivKey string
MTU uint16
@@ -141,11 +141,16 @@ func (w *WGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
}
// UpdateAddr updates address of the interface
func (w *WGIface) UpdateAddr(newAddr wgaddr.Address) error {
func (w *WGIface) UpdateAddr(newAddr string) error {
w.mu.Lock()
defer w.mu.Unlock()
return w.tun.UpdateAddr(newAddr)
addr, err := wgaddr.ParseWGAddress(newAddr)
if err != nil {
return err
}
return w.tun.UpdateAddr(addr)
}
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist

View File

@@ -4,17 +4,23 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
if netstack.IsEnabled() {
wgIFace := &WGIface{
userspaceBind: true,
tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
}
return wgIFace, nil
@@ -22,7 +28,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{
userspaceBind: true,
tun: device.NewTunDevice(opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
}
return wgIFace, nil

View File

@@ -0,0 +1,35 @@
//go:build !ios
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
var tun WGTunDevice
if netstack.IsEnabled() {
tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
} else {
tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
}
wgIFace := &WGIface{
userspaceBind: true,
tun: tun,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
}
return wgIFace, nil
}

View File

@@ -0,0 +1,41 @@
//go:build freebsd
package iface
import (
"fmt"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{}
if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil
}
if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil
}
return nil, fmt.Errorf("couldn't check or load tun module")
}

View File

@@ -5,15 +5,21 @@ package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
}

View File

@@ -4,15 +4,21 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace creates a new WireGuard interface for WASM (always uses netstack mode)
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
relayBind := bind.NewRelayBindJS()
wgIface := &WGIface{
tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()),
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()),
userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(relayBind, opts.MTU),
}

View File

@@ -3,40 +3,44 @@
package iface
import (
"errors"
"fmt"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{}
if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
return &WGIface{
tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
}, nil
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil
}
if device.WireGuardModuleIsLoaded() {
return &WGIface{
tun: device.NewKernelDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet),
wgProxyFactory: wgproxy.NewKernelFactory(opts.WGPort, opts.MTU),
}, nil
wgIFace.tun = device.NewKernelDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet)
wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort, opts.MTU)
return wgIFace, nil
}
if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
return &WGIface{
tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind),
userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
}, nil
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil
}
return nil, errors.New("tun module not available")
return nil, fmt.Errorf("couldn't check or load tun module")
}

View File

@@ -1,28 +1,33 @@
//go:build !linux && !ios && !android && !js
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
wgaddr "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
var tun WGTunDevice
if netstack.IsEnabled() {
tun = device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
} else {
tun = device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
}
return &WGIface{
wgIFace := &WGIface{
userspaceBind: true,
tun: tun,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
}, nil
}
return wgIFace, nil
}

View File

@@ -16,7 +16,6 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
@@ -49,7 +48,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: wgaddr.MustParseWGAddress(addr),
Address: addr,
WGPort: wgPort,
WGPrivKey: key,
MTU: DefaultMTU,
@@ -85,7 +84,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
//update WireGuard address
addr = "100.64.0.2/8"
err = iface.UpdateAddr(wgaddr.MustParseWGAddress(addr))
err = iface.UpdateAddr(addr)
if err != nil {
t.Fatal(err)
}
@@ -131,7 +130,7 @@ func Test_CreateInterface(t *testing.T) {
}
opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: wgaddr.MustParseWGAddress(wgIP),
Address: wgIP,
WGPort: 33100,
WGPrivKey: key,
MTU: DefaultMTU,
@@ -175,7 +174,7 @@ func Test_Close(t *testing.T) {
opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: wgaddr.MustParseWGAddress(wgIP),
Address: wgIP,
WGPort: wgPort,
WGPrivKey: key,
MTU: DefaultMTU,
@@ -220,7 +219,7 @@ func TestRecreation(t *testing.T) {
opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: wgaddr.MustParseWGAddress(wgIP),
Address: wgIP,
WGPort: wgPort,
WGPrivKey: key,
MTU: DefaultMTU,
@@ -292,7 +291,7 @@ func Test_ConfigureInterface(t *testing.T) {
}
opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: wgaddr.MustParseWGAddress(wgIP),
Address: wgIP,
WGPort: wgPort,
WGPrivKey: key,
MTU: DefaultMTU,
@@ -348,7 +347,7 @@ func Test_UpdatePeer(t *testing.T) {
opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: wgaddr.MustParseWGAddress(wgIP),
Address: wgIP,
WGPort: 33100,
WGPrivKey: key,
MTU: DefaultMTU,
@@ -418,7 +417,7 @@ func Test_RemovePeer(t *testing.T) {
opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: wgaddr.MustParseWGAddress(wgIP),
Address: wgIP,
WGPort: 33100,
WGPrivKey: key,
MTU: DefaultMTU,
@@ -483,7 +482,7 @@ func Test_ConnectPeers(t *testing.T) {
optsPeer1 := WGIFaceOpts{
IFaceName: peer1ifaceName,
Address: wgaddr.MustParseWGAddress(peer1wgIP.String()),
Address: peer1wgIP.String(),
WGPort: peer1wgPort,
WGPrivKey: peer1Key.String(),
MTU: DefaultMTU,
@@ -523,7 +522,7 @@ func Test_ConnectPeers(t *testing.T) {
optsPeer2 := WGIFaceOpts{
IFaceName: peer2ifaceName,
Address: wgaddr.MustParseWGAddress(peer2wgIP.String()),
Address: peer2wgIP.String(),
WGPort: peer2wgPort,
WGPrivKey: peer2Key.String(),
MTU: DefaultMTU,

View File

@@ -13,7 +13,7 @@ import (
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
type NetStackTun struct { //nolint:revive
addresses []netip.Addr
address netip.Addr
dnsAddress netip.Addr
mtu int
listenAddress string
@@ -22,9 +22,9 @@ type NetStackTun struct { //nolint:revive
tundev tun.Device
}
func NewNetStackTun(listenAddress string, addresses []netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun {
func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun {
return &NetStackTun{
addresses: addresses,
address: address,
dnsAddress: dnsAddress,
mtu: mtu,
listenAddress: listenAddress,
@@ -33,7 +33,7 @@ func NewNetStackTun(listenAddress string, addresses []netip.Addr, dnsAddress net
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
nsTunDev, tunNet, err := netstack.CreateNetTUN(
t.addresses,
[]netip.Addr{t.address},
[]netip.Addr{t.dnsAddress},
t.mtu)
if err != nil {

View File

@@ -3,18 +3,12 @@ package wgaddr
import (
"fmt"
"net/netip"
"github.com/netbirdio/netbird/shared/netiputil"
)
// Address WireGuard parsed address
type Address struct {
IP netip.Addr
Network netip.Prefix
// IPv6 overlay address, if assigned.
IPv6 netip.Addr
IPv6Net netip.Prefix
}
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
@@ -29,60 +23,6 @@ func ParseWGAddress(address string) (Address, error) {
}, nil
}
// HasIPv6 reports whether a v6 overlay address is assigned.
func (addr Address) HasIPv6() bool {
return addr.IPv6.IsValid()
}
func (addr Address) String() string {
return addr.Prefix().String()
}
// IPv6String returns the v6 address in CIDR notation, or empty string if none.
func (addr Address) IPv6String() string {
if !addr.HasIPv6() {
return ""
}
return addr.IPv6Prefix().String()
}
// Prefix returns the v4 host address with its network prefix length (e.g. 100.64.0.1/16).
func (addr Address) Prefix() netip.Prefix {
return netip.PrefixFrom(addr.IP, addr.Network.Bits())
}
// IPv6Prefix returns the v6 host address with its network prefix length, or a zero prefix if none.
func (addr Address) IPv6Prefix() netip.Prefix {
if !addr.HasIPv6() {
return netip.Prefix{}
}
return netip.PrefixFrom(addr.IPv6, addr.IPv6Net.Bits())
}
// SetIPv6FromCompact decodes a compact prefix (5 or 17 bytes) and sets the IPv6 fields.
// Returns an error if the bytes are invalid. A nil or empty input is a no-op.
//
//nolint:recvcheck
func (addr *Address) SetIPv6FromCompact(raw []byte) error {
if len(raw) == 0 {
return nil
}
prefix, err := netiputil.DecodePrefix(raw)
if err != nil {
return fmt.Errorf("decode v6 overlay address: %w", err)
}
if !prefix.Addr().Is6() {
return fmt.Errorf("expected IPv6 address, got %s", prefix.Addr())
}
addr.IPv6 = prefix.Addr()
addr.IPv6Net = prefix.Masked()
return nil
}
// ClearIPv6 removes the IPv6 overlay address, leaving only v4.
//
//nolint:recvcheck
func (addr *Address) ClearIPv6() {
addr.IPv6 = netip.Addr{}
addr.IPv6Net = netip.Prefix{}
return fmt.Sprintf("%s/%d", addr.IP.String(), addr.Network.Bits())
}

View File

@@ -1,10 +0,0 @@
package wgaddr
// MustParseWGAddress parses and returns a WG Address, panicking on error.
func MustParseWGAddress(address string) Address {
a, err := ParseWGAddress(address)
if err != nil {
panic(err)
}
return a
}

View File

@@ -6,7 +6,7 @@ import (
"fmt"
"net"
"net/netip"
"strings"
"sync"
log "github.com/sirupsen/logrus"
@@ -196,25 +196,18 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
}
}
// fakeAddress returns a fake address that is used as an identifier for the peer.
// The fake address is in the format of 127.1.x.x where x.x is derived from the
// last two bytes of the peer address (works for both IPv4 and IPv6).
// fakeAddress returns a fake address that is used to as an identifier for the peer.
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
if peerAddress == nil {
return nil, fmt.Errorf("nil peer address")
}
if peerAddress.Port < 0 || peerAddress.Port > 65535 {
return nil, fmt.Errorf("invalid UDP port: %d", peerAddress.Port)
}
addr, ok := netip.AddrFromSlice(peerAddress.IP)
if !ok {
octets := strings.Split(peerAddress.IP.String(), ".")
if len(octets) != 4 {
return nil, fmt.Errorf("invalid IP format")
}
addr = addr.Unmap()
raw := addr.As16()
fakeIP := netip.AddrFrom4([4]byte{127, 1, raw[14], raw[15]})
fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3]))
if err != nil {
return nil, fmt.Errorf("parse new IP: %w", err)
}
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
return &netipAddr, nil

View File

@@ -5,6 +5,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"net"
"net/netip"
"strconv"
"sync"
@@ -18,7 +19,6 @@ import (
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
)
var ErrSourceRangesEmpty = errors.New("sources range is empty")
@@ -105,10 +105,6 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
newRulePairs := make(map[id.RuleID][]firewall.Rule)
ipsetByRuleSelectors := make(map[string]string)
// TODO: deny rules should be fatal: if a deny rule fails to apply, we must
// roll back all allow rules to avoid a fail-open where allowed traffic bypasses
// the missing deny. Currently we accumulate errors and continue.
var merr *multierror.Error
for _, r := range rules {
// if this rule is member of rule selection with more than DefaultIPsCountForSet
// it's IP address can be used in the ipset for firewall manager which supports it
@@ -121,8 +117,9 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
}
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err))
continue
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
d.rollBack(newRulePairs)
break
}
if len(rulePair) > 0 {
d.peerRulesPairs[pairID] = rulePair
@@ -130,10 +127,6 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
}
}
if merr != nil {
log.Errorf("failed to apply %d peer ACL rule(s): %v", merr.Len(), nberrors.FormatErrorOrNil(merr))
}
for pairID, rules := range d.peerRulesPairs {
if _, ok := newRulePairs[pairID]; !ok {
for _, rule := range rules {
@@ -223,9 +216,9 @@ func (d *DefaultManager) protoRuleToFirewallRule(
r *mgmProto.FirewallRule,
ipsetName string,
) (id.RuleID, []firewall.Rule, error) {
ip, err := extractRuleIP(r)
if err != nil {
return "", nil, err
ip := net.ParseIP(r.PeerIP)
if ip == nil {
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
}
protocol, err := convertToFirewallProtocol(r.Protocol)
@@ -296,13 +289,13 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
func (d *DefaultManager) addInRules(
id []byte,
ip netip.Addr,
ip net.IP,
protocol firewall.Protocol,
port *firewall.Port,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, nil, port, action, ipsetName)
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, nil, port, action, ipsetName)
if err != nil {
return nil, fmt.Errorf("add firewall rule: %w", err)
}
@@ -312,7 +305,7 @@ func (d *DefaultManager) addInRules(
func (d *DefaultManager) addOutRules(
id []byte,
ip netip.Addr,
ip net.IP,
protocol firewall.Protocol,
port *firewall.Port,
action firewall.Action,
@@ -322,7 +315,7 @@ func (d *DefaultManager) addOutRules(
return nil, nil
}
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, port, nil, action, ipsetName)
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, port, nil, action, ipsetName)
if err != nil {
return nil, fmt.Errorf("add firewall rule: %w", err)
}
@@ -330,9 +323,9 @@ func (d *DefaultManager) addOutRules(
return rule, nil
}
// getPeerRuleID returns unique ID for the rule based on its parameters.
// getPeerRuleID() returns unique ID for the rule based on its parameters.
func (d *DefaultManager) getPeerRuleID(
ip netip.Addr,
ip net.IP,
proto firewall.Protocol,
direction int,
port *firewall.Port,
@@ -351,25 +344,15 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
}
// extractRuleIP extracts the peer IP from a firewall rule.
// If sourcePrefixes is populated (new management), decode the first entry and use its address.
// Otherwise fall back to the deprecated PeerIP string field (old management).
func extractRuleIP(r *mgmProto.FirewallRule) (netip.Addr, error) {
if len(r.SourcePrefixes) > 0 {
addr, err := netiputil.DecodeAddr(r.SourcePrefixes[0])
if err != nil {
return netip.Addr{}, fmt.Errorf("decode source prefix: %w", err)
func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
log.Debugf("rollback ACL to previous state")
for _, rules := range newRulePairs {
for _, rule := range rules {
if err := d.firewall.DeletePeerRule(rule); err != nil {
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.ID(), err)
}
}
return addr.Unmap(), nil
}
//nolint:staticcheck // PeerIP used for backward compatibility with old management
addr, err := netip.ParseAddr(r.PeerIP)
if err != nil {
return netip.Addr{}, fmt.Errorf("invalid IP address, skipping firewall rule")
}
return addr.Unmap(), nil
}
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {

View File

@@ -321,7 +321,6 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
a.config.DisableFirewall,
a.config.BlockLANAccess,
a.config.BlockInbound,
a.config.DisableIPv6,
a.config.LazyConnectionEnabled,
a.config.EnableSSHRoot,
a.config.EnableSSHSFTP,

View File

@@ -14,13 +14,10 @@ import (
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
@@ -539,20 +536,9 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
if config.NetworkMonitor != nil {
nm = *config.NetworkMonitor
}
wgAddr, err := wgaddr.ParseWGAddress(peerConfig.Address)
if err != nil {
return nil, fmt.Errorf("parse overlay address %q: %w", peerConfig.Address, err)
}
if !config.DisableIPv6 {
if err := wgAddr.SetIPv6FromCompact(peerConfig.GetAddressV6()); err != nil {
log.Warn(err)
}
}
engineConf := &EngineConfig{
WgIfaceName: config.WgIface,
WgAddr: wgAddr,
WgAddr: peerConfig.Address,
IFaceBlackList: config.IFaceBlackList,
DisableIPv6Discovery: config.DisableIPv6Discovery,
WgPrivateKey: key,
@@ -577,7 +563,6 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
DisableFirewall: config.DisableFirewall,
BlockLANAccess: config.BlockLANAccess,
BlockInbound: config.BlockInbound,
DisableIPv6: config.DisableIPv6,
LazyConnectionEnabled: config.LazyConnectionEnabled,
@@ -652,7 +637,6 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.DisableFirewall,
config.BlockLANAccess,
config.BlockInbound,
config.DisableIPv6,
config.LazyConnectionEnabled,
config.EnableSSHRoot,
config.EnableSSHSFTP,

View File

@@ -40,10 +40,6 @@ func (noopNetworkChangeListener) SetInterfaceIP(string) {
// network stack, not by OS-level interface configuration.
}
func (noopNetworkChangeListener) SetInterfaceIPv6(string) {
// No-op: same as SetInterfaceIP, IPv6 overlay is managed by userspace stack.
}
// noopDnsReadyListener is a stub for embed.Client on Android.
// DNS readiness notifications are not needed in netstack/embed mode
// since system DNS is disabled and DNS resolution happens externally.

View File

@@ -21,7 +21,6 @@ import (
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/protobuf/encoding/protojson"
"github.com/netbirdio/netbird/client/anonymize"
@@ -31,7 +30,6 @@ import (
"github.com/netbirdio/netbird/client/internal/updater/installer"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
)
const readmeContent = `Netbird debug bundle
@@ -585,9 +583,6 @@ func isSensitiveEnvVar(key string) bool {
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
configContent.WriteString("NetBird Client Configuration:\n\n")
if key, err := wgtypes.ParseKey(g.internalConfig.PrivateKey); err == nil {
configContent.WriteString(fmt.Sprintf("PublicKey: %s\n", key.PublicKey().String()))
}
configContent.WriteString(fmt.Sprintf("WgIface: %s\n", g.internalConfig.WgIface))
configContent.WriteString(fmt.Sprintf("WgPort: %d\n", g.internalConfig.WgPort))
if g.internalConfig.NetworkMonitor != nil {
@@ -612,12 +607,6 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
if g.internalConfig.EnableSSHRemotePortForwarding != nil {
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
}
if g.internalConfig.DisableSSHAuth != nil {
configContent.WriteString(fmt.Sprintf("DisableSSHAuth: %v\n", *g.internalConfig.DisableSSHAuth))
}
if g.internalConfig.SSHJWTCacheTTL != nil {
configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL))
}
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
@@ -625,7 +614,6 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall))
configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess))
configContent.WriteString(fmt.Sprintf("BlockInbound: %v\n", g.internalConfig.BlockInbound))
configContent.WriteString(fmt.Sprintf("DisableIPv6: %v\n", g.internalConfig.DisableIPv6))
if g.internalConfig.DisableNotifications != nil {
configContent.WriteString(fmt.Sprintf("DisableNotifications: %v\n", *g.internalConfig.DisableNotifications))
@@ -645,7 +633,6 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
}
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
configContent.WriteString(fmt.Sprintf("MTU: %d\n", g.internalConfig.MTU))
}
func (g *BundleGenerator) addProf() (err error) {
@@ -1296,21 +1283,6 @@ func anonymizePeerConfig(config *mgmProto.PeerConfig, anonymizer *anonymize.Anon
config.Address = anonymizer.AnonymizeIP(addr).String()
}
if len(config.GetAddressV6()) > 0 {
v6Prefix, err := netiputil.DecodePrefix(config.GetAddressV6())
if err != nil {
config.AddressV6 = nil
} else {
anonV6 := anonymizer.AnonymizeIP(v6Prefix.Addr())
b, err := netiputil.EncodePrefix(netip.PrefixFrom(anonV6, v6Prefix.Bits()))
if err != nil {
config.AddressV6 = nil
} else {
config.AddressV6 = b
}
}
}
anonymizeSSHConfig(config.SshConfig)
config.Dns = anonymizer.AnonymizeString(config.Dns)
@@ -1413,20 +1385,8 @@ func anonymizeFirewallRule(rule *mgmProto.FirewallRule, anonymizer *anonymize.An
return
}
//nolint:staticcheck // PeerIP used for backward compatibility
if addr, err := netip.ParseAddr(rule.PeerIP); err == nil {
rule.PeerIP = anonymizer.AnonymizeIP(addr).String() //nolint:staticcheck
}
for i, raw := range rule.GetSourcePrefixes() {
p, err := netiputil.DecodePrefix(raw)
if err != nil {
continue
}
anonAddr := anonymizer.AnonymizeIP(p.Addr())
if b, err := netiputil.EncodePrefix(netip.PrefixFrom(anonAddr, p.Bits())); err == nil {
rule.SourcePrefixes[i] = b
}
rule.PeerIP = anonymizer.AnonymizeIP(addr).String()
}
}

View File

@@ -5,33 +5,19 @@ import (
"bytes"
"encoding/json"
"net"
"net/netip"
"net/url"
"os"
"path/filepath"
"reflect"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/configs"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
)
func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte {
t.Helper()
b, err := netiputil.EncodePrefix(p)
require.NoError(t, err)
return b
}
func TestAnonymizeStateFile(t *testing.T) {
testState := map[string]json.RawMessage{
"null_state": json.RawMessage("null"),
@@ -182,7 +168,7 @@ func TestAnonymizeStateFile(t *testing.T) {
assert.Equal(t, "100.64.0.1", state["protected_ip"]) // Protected IP unchanged
assert.Equal(t, "8.8.8.8", state["well_known_ip"]) // Well-known IP unchanged
assert.NotEqual(t, "2001:db8::1", state["ipv6_addr"])
assert.NotEqual(t, "fd00::1", state["private_ipv6"]) // ULA IPv6 anonymized (global ID is a fingerprint)
assert.Equal(t, "fd00::1", state["private_ipv6"]) // Private IPv6 unchanged
assert.NotEqual(t, "test.example.com", state["domain"])
assert.True(t, strings.HasSuffix(state["domain"].(string), ".domain"))
assert.Equal(t, "device.netbird.cloud", state["netbird_domain"]) // Netbird domain unchanged
@@ -286,13 +272,11 @@ func mustMarshal(v any) json.RawMessage {
}
func TestAnonymizeNetworkMap(t *testing.T) {
origV6Prefix := netip.MustParsePrefix("2001:db8:abcd::5/64")
networkMap := &mgmProto.NetworkMap{
PeerConfig: &mgmProto.PeerConfig{
Address: "203.0.113.5",
AddressV6: mustEncodePrefix(t, origV6Prefix),
Dns: "1.2.3.4",
Fqdn: "peer1.corp.example.com",
Address: "203.0.113.5",
Dns: "1.2.3.4",
Fqdn: "peer1.corp.example.com",
SshConfig: &mgmProto.SSHConfig{
SshPubKey: []byte("ssh-rsa AAAAB3NzaC1..."),
},
@@ -366,12 +350,6 @@ func TestAnonymizeNetworkMap(t *testing.T) {
require.NotEqual(t, "peer1.corp.example.com", peerCfg.Fqdn)
require.True(t, strings.HasSuffix(peerCfg.Fqdn, ".domain"))
// Verify AddressV6 is anonymized but preserves prefix length
anonV6Prefix, err := netiputil.DecodePrefix(peerCfg.AddressV6)
require.NoError(t, err)
assert.Equal(t, origV6Prefix.Bits(), anonV6Prefix.Bits(), "prefix length must be preserved")
assert.NotEqual(t, origV6Prefix.Addr(), anonV6Prefix.Addr(), "IPv6 address must be anonymized")
// Verify SSH key is replaced
require.Equal(t, []byte("ssh-placeholder-key"), peerCfg.SshConfig.SshPubKey)
@@ -493,8 +471,8 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
anonymize: false,
input: map[string]any{
jsonKeyServiceEnv: map[string]any{
"HOME": "/root",
"PATH": "/usr/bin",
"HOME": "/root",
"PATH": "/usr/bin",
"NB_LOG_LEVEL": "debug",
},
},
@@ -511,9 +489,9 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
anonymize: false,
input: map[string]any{
jsonKeyServiceEnv: map[string]any{
"NB_SETUP_KEY": "abc123",
"NB_API_TOKEN": "tok_xyz",
"NB_LOG_LEVEL": "info",
"NB_SETUP_KEY": "abc123",
"NB_API_TOKEN": "tok_xyz",
"NB_LOG_LEVEL": "info",
},
},
check: func(t *testing.T, params map[string]any) {
@@ -677,6 +655,8 @@ func isInCGNATRange(ip net.IP) bool {
}
func TestAnonymizeFirewallRules(t *testing.T) {
// TODO: Add ipv6
// Example iptables-save output
iptablesSave := `# Generated by iptables-save v1.8.7 on Thu Dec 19 10:00:00 2024
*filter
@@ -712,31 +692,17 @@ Chain FORWARD (policy ACCEPT 0 packets, 0 bytes)
Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
pkts bytes target prot opt in out source destination`
// Example ip6tables-save output
ip6tablesSave := `# Generated by ip6tables-save v1.8.7 on Thu Dec 19 10:00:00 2024
*filter
:INPUT ACCEPT [0:0]
:FORWARD ACCEPT [0:0]
:OUTPUT ACCEPT [0:0]
-A INPUT -s fd00:1234::1/128 -j ACCEPT
-A INPUT -s 2607:f8b0:4005::1/128 -j DROP
-A FORWARD -s 2001:db8::/32 -d 2607:f8b0:4005::200e/128 -j ACCEPT
COMMIT`
// Example nftables output with IPv6
// Example nftables output
nftablesRules := `table inet filter {
chain input {
type filter hook input priority filter; policy accept;
ip saddr 192.168.1.1 accept
ip saddr 44.192.140.1 drop
ip6 saddr 2607:f8b0:4005::1 drop
ip6 saddr fd00:1234::1 accept
}
chain forward {
type filter hook forward priority filter; policy accept;
ip saddr 10.0.0.0/8 drop
ip saddr 44.192.140.0/24 ip daddr 52.84.12.34/24 accept
ip6 saddr 2001:db8::/32 ip6 daddr 2607:f8b0:4005::200e/128 accept
}
}`
@@ -799,159 +765,4 @@ COMMIT`
assert.Contains(t, anonNftables, "table inet filter {")
assert.Contains(t, anonNftables, "chain input {")
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
// IPv6 public addresses in nftables should be anonymized
assert.NotContains(t, anonNftables, "2607:f8b0:4005::1")
assert.NotContains(t, anonNftables, "2607:f8b0:4005::200e")
assert.NotContains(t, anonNftables, "2001:db8::")
assert.Contains(t, anonNftables, "2001:db8:ffff::") // Default anonymous v6 range
// ULA addresses in nftables should be anonymized (global ID is a fingerprint)
assert.NotContains(t, anonNftables, "fd00:1234::1")
// IPv6 nftables structure preserved
assert.Contains(t, anonNftables, "ip6 saddr")
assert.Contains(t, anonNftables, "ip6 daddr")
// Test ip6tables-save anonymization
anonIp6tablesSave := anonymizer.AnonymizeString(ip6tablesSave)
// ULA IPv6 should be anonymized (global ID is a fingerprint)
assert.NotContains(t, anonIp6tablesSave, "fd00:1234::1/128")
// Public IPv6 addresses should be anonymized
assert.NotContains(t, anonIp6tablesSave, "2607:f8b0:4005::1")
assert.NotContains(t, anonIp6tablesSave, "2607:f8b0:4005::200e")
assert.NotContains(t, anonIp6tablesSave, "2001:db8::")
assert.Contains(t, anonIp6tablesSave, "2001:db8:ffff::") // Default anonymous v6 range
// Structure should be preserved
assert.Contains(t, anonIp6tablesSave, "*filter")
assert.Contains(t, anonIp6tablesSave, "COMMIT")
assert.Contains(t, anonIp6tablesSave, "-j DROP")
assert.Contains(t, anonIp6tablesSave, "-j ACCEPT")
}
// TestAddConfig_AllFieldsCovered uses reflection to ensure every field in
// profilemanager.Config is either rendered in the debug bundle or explicitly
// excluded. When a new field is added to Config, this test fails until the
// developer either dumps it in addConfig/addCommonConfigFields or adds it to
// the excluded set with a justification.
func TestAddConfig_AllFieldsCovered(t *testing.T) {
excluded := map[string]string{
"PrivateKey": "sensitive: WireGuard private key",
"PreSharedKey": "sensitive: WireGuard pre-shared key",
"SSHKey": "sensitive: SSH private key",
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
}
mURL, _ := url.Parse("https://api.example.com:443")
aURL, _ := url.Parse("https://admin.example.com:443")
bTrue := true
iVal := 42
cfg := &profilemanager.Config{
PrivateKey: "priv",
PreSharedKey: "psk",
ManagementURL: mURL,
AdminURL: aURL,
WgIface: "wt0",
WgPort: 51820,
NetworkMonitor: &bTrue,
IFaceBlackList: []string{"eth0"},
DisableIPv6Discovery: true,
RosenpassEnabled: true,
RosenpassPermissive: true,
ServerSSHAllowed: &bTrue,
EnableSSHRoot: &bTrue,
EnableSSHSFTP: &bTrue,
EnableSSHLocalPortForwarding: &bTrue,
EnableSSHRemotePortForwarding: &bTrue,
DisableSSHAuth: &bTrue,
SSHJWTCacheTTL: &iVal,
DisableClientRoutes: true,
DisableServerRoutes: true,
DisableDNS: true,
DisableFirewall: true,
BlockLANAccess: true,
BlockInbound: true,
DisableNotifications: &bTrue,
DNSLabels: domain.List{},
SSHKey: "sshkey",
NATExternalIPs: []string{"1.2.3.4"},
CustomDNSAddress: "1.1.1.1:53",
DisableAutoConnect: true,
DNSRouteInterval: 5 * time.Second,
ClientCertPath: "/tmp/cert",
ClientCertKeyPath: "/tmp/key",
LazyConnectionEnabled: true,
MTU: 1280,
}
for _, anonymize := range []bool{false, true} {
t.Run("anonymize="+map[bool]string{true: "true", false: "false"}[anonymize], func(t *testing.T) {
g := &BundleGenerator{
anonymizer: newAnonymizerForTest(),
internalConfig: cfg,
anonymize: anonymize,
}
var sb strings.Builder
g.addCommonConfigFields(&sb)
rendered := sb.String() + renderAddConfigSpecific(g)
val := reflect.ValueOf(cfg).Elem()
typ := val.Type()
var missing []string
for i := 0; i < typ.NumField(); i++ {
name := typ.Field(i).Name
if _, ok := excluded[name]; ok {
continue
}
if !strings.Contains(rendered, name+":") {
missing = append(missing, name)
}
}
if len(missing) > 0 {
t.Fatalf("Config field(s) not present in debug bundle output: %v\n"+
"Either render the field in addCommonConfigFields/addConfig, "+
"or add it to the excluded map with a justification.", missing)
}
})
}
}
// renderAddConfigSpecific renders the fields handled by the anonymize/non-anonymize
// branches in addConfig (ManagementURL, AdminURL, NATExternalIPs, CustomDNSAddress).
// addCommonConfigFields covers the rest. Keeping this in the test mirrors the
// production shape without needing to write an actual zip.
func renderAddConfigSpecific(g *BundleGenerator) string {
var sb strings.Builder
if g.anonymize {
if g.internalConfig.ManagementURL != nil {
sb.WriteString("ManagementURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.ManagementURL.String()) + "\n")
}
if g.internalConfig.AdminURL != nil {
sb.WriteString("AdminURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.AdminURL.String()) + "\n")
}
sb.WriteString("NATExternalIPs: x\n")
if g.internalConfig.CustomDNSAddress != "" {
sb.WriteString("CustomDNSAddress: " + g.anonymizer.AnonymizeString(g.internalConfig.CustomDNSAddress) + "\n")
}
} else {
if g.internalConfig.ManagementURL != nil {
sb.WriteString("ManagementURL: " + g.internalConfig.ManagementURL.String() + "\n")
}
if g.internalConfig.AdminURL != nil {
sb.WriteString("AdminURL: " + g.internalConfig.AdminURL.String() + "\n")
}
sb.WriteString("NATExternalIPs: x\n")
if g.internalConfig.CustomDNSAddress != "" {
sb.WriteString("CustomDNSAddress: " + g.internalConfig.CustomDNSAddress + "\n")
}
}
return sb.String()
}
func newAnonymizerForTest() *anonymize.Anonymizer {
return anonymize.NewAnonymizer(anonymize.DefaultAddresses())
}

View File

@@ -12,83 +12,52 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
)
func createPTRRecord(record nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) {
ip, err := netip.ParseAddr(record.RData)
func createPTRRecord(aRecord nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) {
ip, err := netip.ParseAddr(aRecord.RData)
if err != nil {
log.Warnf("failed to parse IP address %s: %v", record.RData, err)
log.Warnf("failed to parse IP address %s: %v", aRecord.RData, err)
return nbdns.SimpleRecord{}, false
}
ip = ip.Unmap()
if !prefix.Contains(ip) {
return nbdns.SimpleRecord{}, false
}
var rdnsName string
if ip.Is4() {
octets := strings.Split(ip.String(), ".")
slices.Reverse(octets)
rdnsName = dns.Fqdn(strings.Join(octets, ".") + ".in-addr.arpa")
} else {
// Expand to full 32 nibbles in reverse order (LSB first) per RFC 3596.
raw := ip.As16()
nibbles := make([]string, 32)
for i := 0; i < 16; i++ {
nibbles[31-i*2] = fmt.Sprintf("%x", raw[i]>>4)
nibbles[31-i*2-1] = fmt.Sprintf("%x", raw[i]&0x0f)
}
rdnsName = dns.Fqdn(strings.Join(nibbles, ".") + ".ip6.arpa")
}
ipOctets := strings.Split(ip.String(), ".")
slices.Reverse(ipOctets)
rdnsName := dns.Fqdn(strings.Join(ipOctets, ".") + ".in-addr.arpa")
return nbdns.SimpleRecord{
Name: rdnsName,
Type: int(dns.TypePTR),
Class: record.Class,
TTL: record.TTL,
RData: dns.Fqdn(record.Name),
Class: aRecord.Class,
TTL: aRecord.TTL,
RData: dns.Fqdn(aRecord.Name),
}, true
}
// generateReverseZoneName creates the reverse DNS zone name for a given network.
// For IPv4 it produces an in-addr.arpa name, for IPv6 an ip6.arpa name.
// generateReverseZoneName creates the reverse DNS zone name for a given network
func generateReverseZoneName(network netip.Prefix) (string, error) {
networkIP := network.Masked().Addr().Unmap()
bits := network.Bits()
networkIP := network.Masked().Addr()
if networkIP.Is4() {
// Round up to nearest byte.
octetsToUse := (bits + 7) / 8
octets := strings.Split(networkIP.String(), ".")
if octetsToUse > len(octets) {
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", bits)
}
reverseOctets := make([]string, octetsToUse)
for i := 0; i < octetsToUse; i++ {
reverseOctets[octetsToUse-1-i] = octets[i]
}
return dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa"), nil
if !networkIP.Is4() {
return "", fmt.Errorf("reverse DNS is only supported for IPv4 networks, got: %s", networkIP)
}
// IPv6: round up to nearest nibble (4-bit boundary).
nibblesToUse := (bits + 3) / 4
// round up to nearest byte
octetsToUse := (network.Bits() + 7) / 8
raw := networkIP.As16()
allNibbles := make([]string, 32)
for i := 0; i < 16; i++ {
allNibbles[i*2] = fmt.Sprintf("%x", raw[i]>>4)
allNibbles[i*2+1] = fmt.Sprintf("%x", raw[i]&0x0f)
octets := strings.Split(networkIP.String(), ".")
if octetsToUse > len(octets) {
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", network.Bits())
}
// Take the first nibblesToUse nibbles (network portion), reverse them.
used := make([]string, nibblesToUse)
for i := 0; i < nibblesToUse; i++ {
used[nibblesToUse-1-i] = allNibbles[i]
reverseOctets := make([]string, octetsToUse)
for i := 0; i < octetsToUse; i++ {
reverseOctets[octetsToUse-1-i] = octets[i]
}
return dns.Fqdn(strings.Join(used, ".") + ".ip6.arpa"), nil
return dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa"), nil
}
// zoneExists checks if a zone with the given name already exists in the configuration
@@ -102,7 +71,7 @@ func zoneExists(config *nbdns.Config, zoneName string) bool {
return false
}
// collectPTRRecords gathers all PTR records for the given network from A and AAAA records.
// collectPTRRecords gathers all PTR records for the given network from A records
func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.SimpleRecord {
var records []nbdns.SimpleRecord
@@ -111,7 +80,7 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
continue
}
for _, record := range zone.Records {
if record.Type != int(dns.TypeA) && record.Type != int(dns.TypeAAAA) {
if record.Type != int(dns.TypeA) {
continue
}

View File

@@ -0,0 +1,63 @@
package dnsfw
import (
"os"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
)
const (
// EnvDisable disables the DNS firewall entirely when set to a truthy value.
EnvDisable = "NB_DISABLE_DNS_FIREWALL"
// EnvPorts overrides the comma-separated list of remote ports to block.
// Empty disables the firewall.
EnvPorts = "NB_DNS_FIREWALL_PORTS"
// EnvStrict enables strict mode: permit DNS only to the virtual DNS IP
// and the netbird daemon. Default mode also permits anything on the
// netbird tunnel interface, which is safer if NRPT is silently ignored
// by Windows but lets apps reach custom DNS servers via the tunnel.
EnvStrict = "NB_DNS_FIREWALL_STRICT"
)
// defaultBlockedPorts are the well-known DNS ports we block for non-netbird
// processes: 53 (plain DNS) and 853 (DNS-over-TLS).
var defaultBlockedPorts = []uint16{53, 853}
// blockedPorts returns the effective port list, honoring env overrides.
// A nil return means the firewall should not be installed.
func blockedPorts() []uint16 {
if disabled, _ := strconv.ParseBool(os.Getenv(EnvDisable)); disabled {
log.Infof("dns firewall disabled via %s", EnvDisable)
return nil
}
override, ok := os.LookupEnv(EnvPorts)
if !ok {
return defaultBlockedPorts
}
var ports []uint16
for _, raw := range strings.Split(override, ",") {
raw = strings.TrimSpace(raw)
if raw == "" {
continue
}
port, err := strconv.ParseUint(raw, 10, 16)
if err != nil {
log.Warnf("dns firewall: ignoring invalid port %q in %s: %v", raw, EnvPorts, err)
continue
}
if port == 0 {
log.Warnf("dns firewall: ignoring port 0 in %s", EnvPorts)
continue
}
ports = append(ports, uint16(port))
}
if len(ports) == 0 {
log.Infof("dns firewall disabled: %s yielded no valid ports", EnvPorts)
return nil
}
return ports
}

View File

@@ -0,0 +1,39 @@
package dnsfw
import (
"reflect"
"testing"
)
func TestBlockedPorts(t *testing.T) {
tests := []struct {
name string
disable string
ports string
setPorts bool
want []uint16
}{
{name: "default", want: defaultBlockedPorts},
{name: "disabled", disable: "true", want: nil},
{name: "disabled false keeps default", disable: "false", want: defaultBlockedPorts},
{name: "override single port", ports: "53", setPorts: true, want: []uint16{53}},
{name: "override multi", ports: "53, 853 ,5353", setPorts: true, want: []uint16{53, 853, 5353}},
{name: "override empty disables", ports: "", setPorts: true, want: nil},
{name: "override invalid skipped", ports: "53,not-a-port,853", setPorts: true, want: []uint16{53, 853}},
{name: "override zero skipped", ports: "53,0,853", setPorts: true, want: []uint16{53, 853}},
{name: "override only invalid disables", ports: "abc", setPorts: true, want: nil},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(EnvDisable, tc.disable)
if tc.setPorts {
t.Setenv(EnvPorts, tc.ports)
}
got := blockedPorts()
if !reflect.DeepEqual(got, tc.want) {
t.Fatalf("blockedPorts() = %v, want %v", got, tc.want)
}
})
}
}

View File

@@ -0,0 +1,16 @@
// Package dnsfw blocks DNS traffic from non-netbird processes when netbird is
// managing the host's DNS, so that resolvers running on apps or libraries
// outside netbird cannot bypass the configured DNS path.
//
// Implementation is Windows-only (uses WFP). On other platforms New returns
// a no-op manager.
package dnsfw
import "net/netip"
// Manager controls the per-tunnel DNS firewall. Both methods must be safe
// to call multiple times.
type Manager interface {
Enable(ifaceGUID string, virtualDNSIP netip.Addr) error
Disable() error
}

View File

@@ -0,0 +1,15 @@
//go:build !windows
package dnsfw
import "net/netip"
type noopManager struct{}
func (noopManager) Enable(string, netip.Addr) error { return nil }
func (noopManager) Disable() error { return nil }
// New returns a no-op manager on non-Windows platforms.
func New() Manager {
return noopManager{}
}

View File

@@ -0,0 +1,144 @@
//go:build windows
package dnsfw
import (
"fmt"
"net/netip"
"os"
"strconv"
"sync"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
var (
modIphlpapi = windows.NewLazyDLL("iphlpapi.dll")
procConvertInterfaceGuidToLuid = modIphlpapi.NewProc("ConvertInterfaceGuidToLuid")
)
type windowsManager struct {
mu sync.Mutex
// session is the WFP engine handle. Zero when disabled.
session uintptr
}
// Enable installs the dns firewall. Strict mode propagates failures;
// non-strict mode logs and returns nil so partial protection is preserved.
func (m *windowsManager) Enable(ifaceGUID string, virtualDNSIP netip.Addr) error {
m.mu.Lock()
defer m.mu.Unlock()
ports := blockedPorts()
if len(ports) == 0 {
return nil
}
if m.session != 0 {
if err := m.disableLocked(); err != nil {
return fmt.Errorf("reset existing dns firewall session: %w", err)
}
}
strict := strictMode()
luid, err := luidFromGUID(ifaceGUID)
if err != nil {
return m.failOrLog(strict, fmt.Errorf("resolve tun luid from guid %s: %w", ifaceGUID, err))
}
exe, err := os.Executable()
if err != nil {
return m.failOrLog(strict, fmt.Errorf("resolve daemon executable path: %w", err))
}
cfg := installConfig{
tunLUID: luid,
daemonExe: exe,
blockedPorts: ports,
strict: strict,
virtualDNSIP: virtualDNSIP,
}
// session==0 signals a hard failure; non-zero with non-nil err is a partial install.
session, installErr := installFilters(cfg)
if session == 0 {
return m.failOrLog(strict, fmt.Errorf("install dns firewall filters: %w", installErr))
}
if installErr != nil && strict {
_ = closeSession(session)
return fmt.Errorf("strict dns firewall: partial install: %w", installErr)
}
m.session = session
log.Infof("dns firewall installed: iface=%s daemon=%s ports=%v strict=%v virtual_dns=%s",
ifaceGUID, exe, ports, strict, virtualDNSIP)
if installErr != nil {
log.Warnf("dns firewall partially installed (some filters failed): %v", installErr)
}
return nil
}
func (m *windowsManager) Disable() error {
m.mu.Lock()
defer m.mu.Unlock()
return m.disableLocked()
}
func (m *windowsManager) disableLocked() error {
if m.session == 0 {
return nil
}
session := m.session
m.session = 0
if err := closeSession(session); err != nil {
return fmt.Errorf("close wfp session: %w", err)
}
log.Info("dns firewall removed")
return nil
}
// failOrLog returns err unchanged in strict mode. In non-strict mode the
// error is logged and nil is returned.
func (m *windowsManager) failOrLog(strict bool, err error) error {
if strict {
return err
}
log.Errorf("dns firewall: %v", err)
return nil
}
// New returns a Windows DNS firewall manager backed by WFP.
func New() Manager {
return &windowsManager{}
}
// strictMode reports whether strict mode is enabled via env.
func strictMode() bool {
v, _ := strconv.ParseBool(os.Getenv(EnvStrict))
return v
}
// luidFromGUID converts a Windows interface GUID string to its LUID.
func luidFromGUID(ifaceGUID string) (luid uint64, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic in luidFromGUID: %v", r)
}
}()
guid, err := windows.GUIDFromString(ifaceGUID)
if err != nil {
return 0, fmt.Errorf("parse guid: %w", err)
}
rc, _, _ := procConvertInterfaceGuidToLuid.Call(
uintptr(unsafe.Pointer(&guid)),
uintptr(unsafe.Pointer(&luid)),
)
if rc != 0 {
return 0, fmt.Errorf("ConvertInterfaceGuidToLuid returned %d", rc)
}
return luid, nil
}

View File

@@ -0,0 +1,72 @@
//go:build windows
package dnsfw
import (
"net/netip"
"os"
"testing"
)
func TestStrictMode(t *testing.T) {
tests := []struct {
name string
val string
set bool
want bool
}{
{name: "unset", want: false},
{name: "true", val: "true", set: true, want: true},
{name: "1", val: "1", set: true, want: true},
{name: "false", val: "false", set: true, want: false},
{name: "invalid is false", val: "garbage", set: true, want: false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(EnvStrict, tc.val)
if !tc.set {
os.Unsetenv(EnvStrict)
}
if got := strictMode(); got != tc.want {
t.Fatalf("strictMode() = %v, want %v", got, tc.want)
}
})
}
}
func TestWindowsManagerDisableIdempotent(t *testing.T) {
m := &windowsManager{}
if err := m.Disable(); err != nil {
t.Fatalf("first Disable on fresh manager: %v", err)
}
if err := m.Disable(); err != nil {
t.Fatalf("second Disable on fresh manager: %v", err)
}
if m.session != 0 {
t.Fatalf("session should remain zero, got %d", m.session)
}
}
func TestWindowsManagerEnableNoOpWhenDisabledByEnv(t *testing.T) {
t.Setenv(EnvDisable, "true")
m := &windowsManager{}
if err := m.Enable("00000000-0000-0000-0000-000000000000", netip.Addr{}); err != nil {
t.Fatalf("Enable should be a no-op when firewall disabled by env: %v", err)
}
if m.session != 0 {
t.Fatalf("session must remain zero when env disables firewall, got %d", m.session)
}
}
func TestWindowsManagerEnableNoOpWhenPortsEmpty(t *testing.T) {
t.Setenv(EnvPorts, "")
m := &windowsManager{}
if err := m.Enable("00000000-0000-0000-0000-000000000000", netip.Addr{}); err != nil {
t.Fatalf("Enable should be a no-op when ports list is empty: %v", err)
}
if m.session != 0 {
t.Fatalf("session must remain zero when ports list is empty, got %d", m.session)
}
}

View File

@@ -0,0 +1,53 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*
* Adapted from wireguard-windows tunnel/firewall/helpers.go.
*/
package dnsfw
import (
"errors"
"fmt"
"runtime"
"syscall"
"golang.org/x/sys/windows"
)
func createWtFwpmDisplayData0(name, description string) (*wtFwpmDisplayData0, error) {
namePtr, err := windows.UTF16PtrFromString(name)
if err != nil {
return nil, wrapErr(err)
}
descriptionPtr, err := windows.UTF16PtrFromString(description)
if err != nil {
return nil, wrapErr(err)
}
return &wtFwpmDisplayData0{
name: namePtr,
description: descriptionPtr,
}, nil
}
func filterWeight(weight uint8) wtFwpValue0 {
return wtFwpValue0{
_type: cFWP_UINT8,
value: uintptr(weight),
}
}
func wrapErr(err error) error {
var errno syscall.Errno
if !errors.As(err, &errno) {
return err
}
_, file, line, ok := runtime.Caller(1)
if !ok {
return fmt.Errorf("wfp error at unknown location: %w", err)
}
return fmt.Errorf("wfp error at %s:%d: %w", file, line, err)
}

View File

@@ -0,0 +1,249 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2026 NetBird GmbH. All Rights Reserved.
*
* Filter installers adapted from wireguard-windows tunnel/firewall/rules.go.
* The block-DNS approach (port 53 + UDP/TCP) matches what wireguard-windows
* uses for its kill-switch DNS leak protection. We extend it with a
* configurable port set so we also cover :853 (DoT) and any future ports.
*/
package dnsfw
import (
"encoding/binary"
"fmt"
"net/netip"
"unsafe"
"github.com/hashicorp/go-multierror"
"golang.org/x/sys/windows"
nberrors "github.com/netbirdio/netbird/client/errors"
)
// Filters install at outbound ALE_AUTH_CONNECT layers only; inbound replies
// follow the authorized outbound flow.
// permitTunInterface installs a permit filter for any traffic whose local
// interface is the netbird tunnel.
func permitTunInterface(session uintptr, base *baseObjects, weight uint8, ifLUID uint64) error {
cond := wtFwpmFilterCondition0{
fieldKey: cFWPM_CONDITION_IP_LOCAL_INTERFACE,
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_UINT64,
value: uintptr(unsafe.Pointer(&ifLUID)),
},
}
filter := wtFwpmFilter0{
providerKey: &base.provider,
subLayerKey: base.filters,
weight: filterWeight(weight),
numFilterConditions: 1,
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&cond)),
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
}
return addOutboundFilters(session, &filter, "Permit netbird tunnel")
}
// permitDaemonByAppID installs a permit filter matching the netbird daemon
// executable by App-ID. App-ID alone is sufficient because netbird.exe is a
// dedicated binary.
func permitDaemonByAppID(session uintptr, base *baseObjects, daemonExe string, weight uint8) error {
appID, err := daemonAppID(daemonExe)
if err != nil {
return err
}
defer fwpmFreeMemory0(unsafe.Pointer(&appID))
cond := wtFwpmFilterCondition0{
fieldKey: cFWPM_CONDITION_ALE_APP_ID,
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_BYTE_BLOB_TYPE,
value: uintptr(unsafe.Pointer(appID)),
},
}
filter := wtFwpmFilter0{
providerKey: &base.provider,
subLayerKey: base.filters,
weight: filterWeight(weight),
numFilterConditions: 1,
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&cond)),
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
}
return addOutboundFilters(session, &filter, "Permit netbird daemon")
}
// permitVirtualDNSIP installs a permit filter for DNS-port traffic destined
// for the in-tunnel virtual DNS IP. Used in strict mode in lieu of
// permitTunInterface.
func permitVirtualDNSIP(session uintptr, base *baseObjects, ip netip.Addr, ports []uint16, weight uint8) error {
var merr *multierror.Error
for _, port := range ports {
if err := permitDNSToHost(session, base, ip, port, weight); err != nil {
merr = multierror.Append(merr, fmt.Errorf("permit %s:%d: %w", ip, port, err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func permitDNSToHost(session uintptr, base *baseObjects, ip netip.Addr, port uint16, weight uint8) error {
if !ip.IsValid() {
return fmt.Errorf("invalid address")
}
var addrCond wtFwpmFilterCondition0
var layer windows.GUID
// v6 backing must outlive fwpmFilterAdd0; keep it on this stack frame.
var v6 wtFwpByteArray16
if ip.Is4() {
v4 := ip.As4()
addrCond = wtFwpmFilterCondition0{
fieldKey: cFWPM_CONDITION_IP_REMOTE_ADDRESS,
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_UINT32,
value: uintptr(binary.BigEndian.Uint32(v4[:])),
},
}
layer = cFWPM_LAYER_ALE_AUTH_CONNECT_V4
} else {
v6 = wtFwpByteArray16{byteArray16: ip.As16()}
addrCond = wtFwpmFilterCondition0{
fieldKey: cFWPM_CONDITION_IP_REMOTE_ADDRESS,
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_BYTE_ARRAY16_TYPE,
value: uintptr(unsafe.Pointer(&v6)),
},
}
layer = cFWPM_LAYER_ALE_AUTH_CONNECT_V6
}
conditions := [2]wtFwpmFilterCondition0{
addrCond,
{
fieldKey: cFWPM_CONDITION_IP_REMOTE_PORT,
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_UINT16,
value: uintptr(port),
},
},
}
filter := wtFwpmFilter0{
providerKey: &base.provider,
subLayerKey: base.filters,
weight: filterWeight(weight),
numFilterConditions: uint32(len(conditions)),
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&conditions[0])),
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
}
display, err := createWtFwpmDisplayData0(fmt.Sprintf("Permit DNS to %s:%d", ip, port), "")
if err != nil {
return wrapErr(err)
}
filter.displayData = *display
filter.layerKey = layer
var filterID uint64
if err := fwpmFilterAdd0(session, &filter, 0, &filterID); err != nil {
return wrapErr(err)
}
_ = v6
return nil
}
// blockDNSPorts installs a deny filter for outbound traffic to each of the
// given remote ports over UDP or TCP. Per-port and per-layer failures are
// accumulated; partial coverage is preferred over zero coverage.
func blockDNSPorts(session uintptr, base *baseObjects, ports []uint16, weight uint8) error {
var merr *multierror.Error
for _, port := range ports {
if err := blockDNSPort(session, base, port, weight); err != nil {
merr = multierror.Append(merr, fmt.Errorf("block port %d: %w", port, err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func blockDNSPort(session uintptr, base *baseObjects, port uint16, weight uint8) error {
conditions := [3]wtFwpmFilterCondition0{
{
fieldKey: cFWPM_CONDITION_IP_REMOTE_PORT,
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_UINT16,
value: uintptr(port),
},
},
{
fieldKey: cFWPM_CONDITION_IP_PROTOCOL,
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_UINT8,
value: uintptr(cIPPROTO_UDP),
},
},
// Repeat the IP_PROTOCOL condition for logical OR with TCP.
{
fieldKey: cFWPM_CONDITION_IP_PROTOCOL,
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_UINT8,
value: uintptr(cIPPROTO_TCP),
},
},
}
filter := wtFwpmFilter0{
providerKey: &base.provider,
subLayerKey: base.filters,
weight: filterWeight(weight),
numFilterConditions: uint32(len(conditions)),
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&conditions[0])),
action: wtFwpmAction0{_type: cFWP_ACTION_BLOCK},
}
return addOutboundFilters(session, &filter, fmt.Sprintf("Block DNS port %d", port))
}
// addOutboundFilters installs the same filter on the v4 and v6 outbound ALE
// connect layers. v4 and v6 are installed independently: failure on one
// layer does not abort the other, and the accumulated errors are returned.
// Partial coverage is preferred over zero coverage.
func addOutboundFilters(session uintptr, filter *wtFwpmFilter0, name string) error {
layers := [...]struct {
layer windows.GUID
label string
}{
{cFWPM_LAYER_ALE_AUTH_CONNECT_V4, name + " (IPv4)"},
{cFWPM_LAYER_ALE_AUTH_CONNECT_V6, name + " (IPv6)"},
}
var merr *multierror.Error
for _, l := range layers {
display, err := createWtFwpmDisplayData0(l.label, "")
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("%s: %w", l.label, wrapErr(err)))
continue
}
filter.displayData = *display
filter.layerKey = l.layer
var filterID uint64
if err := fwpmFilterAdd0(session, filter, 0, &filterID); err != nil {
merr = multierror.Append(merr, fmt.Errorf("%s: %w", l.label, wrapErr(err)))
}
}
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -0,0 +1,177 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2026 NetBird GmbH. All Rights Reserved.
*
* Session lifecycle and the high-level Install/Close entry points adapted
* from wireguard-windows tunnel/firewall.
*/
package dnsfw
import (
"errors"
"fmt"
"net/netip"
"unsafe"
"github.com/hashicorp/go-multierror"
"golang.org/x/sys/windows"
nberrors "github.com/netbirdio/netbird/client/errors"
)
// installConfig is the input to installFilters.
type installConfig struct {
tunLUID uint64
daemonExe string
blockedPorts []uint16
// strict, when true, narrows the carve-out from "anything on tun" to
// "DNS only to virtualDNSIP". virtualDNSIP must be valid in this case.
strict bool
virtualDNSIP netip.Addr
}
// baseObjects holds the GUIDs of the WFP provider and sublayer registered
// for our session. Both are randomly generated per session.
type baseObjects struct {
provider windows.GUID
filters windows.GUID
}
// installFilters opens a dynamic WFP session and installs the netbird DNS
// firewall filters. Returns a zero session on hard failure (session create,
// base objects); a non-zero session with a non-nil error is a partial install
// (some per-filter installs failed) and is safe to close.
func installFilters(cfg installConfig) (session uintptr, err error) {
defer func() {
if r := recover(); r != nil {
// Dynamic session: kernel will clean up on process exit even
// if we leave the handle dangling here.
err = fmt.Errorf("panic in installFilters: %v", r)
}
}()
if len(cfg.blockedPorts) == 0 {
return 0, errors.New("dns firewall: no blocked ports configured")
}
if cfg.strict && !cfg.virtualDNSIP.IsValid() {
return 0, errors.New("dns firewall: strict mode requires a valid virtual DNS IP")
}
session, err = createSession()
if err != nil {
return 0, err
}
base, err := registerBaseObjects(session)
if err != nil {
_ = fwpmEngineClose0(session)
return 0, fmt.Errorf("register base objects: %w", err)
}
var merr *multierror.Error
if cfg.strict {
if err := permitVirtualDNSIP(session, base, cfg.virtualDNSIP, cfg.blockedPorts, 15); err != nil {
merr = multierror.Append(merr, fmt.Errorf("permit virtual dns: %w", err))
}
} else {
if err := permitTunInterface(session, base, 15, cfg.tunLUID); err != nil {
merr = multierror.Append(merr, fmt.Errorf("permit tun interface: %w", err))
}
}
if err := permitDaemonByAppID(session, base, cfg.daemonExe, 14); err != nil {
merr = multierror.Append(merr, fmt.Errorf("permit netbird daemon: %w", err))
}
if err := blockDNSPorts(session, base, cfg.blockedPorts, 10); err != nil {
merr = multierror.Append(merr, fmt.Errorf("block dns ports: %w", err))
}
return session, nberrors.FormatErrorOrNil(merr)
}
// closeSession tears down a WFP session previously opened by installFilters.
// All filters owned by the session are removed.
func closeSession(session uintptr) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic in closeSession: %v", r)
}
}()
if session == 0 {
return nil
}
if err := fwpmEngineClose0(session); err != nil {
return wrapErr(err)
}
return nil
}
func createSession() (uintptr, error) {
displayData, err := createWtFwpmDisplayData0("NetBird DNS firewall", "NetBird DNS firewall dynamic session")
if err != nil {
return 0, wrapErr(err)
}
session := wtFwpmSession0{
displayData: *displayData,
flags: cFWPM_SESSION_FLAG_DYNAMIC,
txnWaitTimeoutInMSec: windows.INFINITE,
}
var handle uintptr
if err := fwpmEngineOpen0(nil, cRPC_C_AUTHN_WINNT, nil, &session, unsafe.Pointer(&handle)); err != nil {
return 0, wrapErr(err)
}
return handle, nil
}
func registerBaseObjects(session uintptr) (*baseObjects, error) {
bo := &baseObjects{}
var err error
if bo.provider, err = windows.GenerateGUID(); err != nil {
return nil, wrapErr(err)
}
if bo.filters, err = windows.GenerateGUID(); err != nil {
return nil, wrapErr(err)
}
displayData, err := createWtFwpmDisplayData0("NetBird DNS firewall", "NetBird DNS firewall provider")
if err != nil {
return nil, wrapErr(err)
}
provider := wtFwpmProvider0{
providerKey: bo.provider,
displayData: *displayData,
}
if err := fwpmProviderAdd0(session, &provider, 0); err != nil {
return nil, wrapErr(err)
}
subDisplay, err := createWtFwpmDisplayData0("NetBird DNS firewall filters", "Permit and block filters")
if err != nil {
return nil, wrapErr(err)
}
sublayer := wtFwpmSublayer0{
subLayerKey: bo.filters,
displayData: *subDisplay,
providerKey: &bo.provider,
weight: ^uint16(0),
}
if err := fwpmSubLayerAdd0(session, &sublayer, 0); err != nil {
return nil, wrapErr(err)
}
return bo, nil
}
// daemonAppID returns the WFP App-ID byte blob for the given executable path.
func daemonAppID(path string) (*wtFwpByteBlob, error) {
pathPtr, err := windows.UTF16PtrFromString(path)
if err != nil {
return nil, wrapErr(err)
}
var appID *wtFwpByteBlob
if err := fwpmGetAppIdFromFileName0(pathPtr, unsafe.Pointer(&appID)); err != nil {
return nil, wrapErr(err)
}
return appID, nil
}

View File

@@ -0,0 +1,38 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*
* Adapted from wireguard-windows tunnel/firewall/syscall_windows.go.
*/
package dnsfw
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmengineopen0
//sys fwpmEngineOpen0(serverName *uint16, authnService wtRpcCAuthN, authIdentity *uintptr, session *wtFwpmSession0, engineHandle unsafe.Pointer) (err error) [failretval!=0] = fwpuclnt.FwpmEngineOpen0
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmengineclose0
//sys fwpmEngineClose0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmEngineClose0
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmsublayeradd0
//sys fwpmSubLayerAdd0(engineHandle uintptr, subLayer *wtFwpmSublayer0, sd uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmSubLayerAdd0
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmgetappidfromfilename0
//sys fwpmGetAppIdFromFileName0(fileName *uint16, appID unsafe.Pointer) (err error) [failretval!=0] = fwpuclnt.FwpmGetAppIdFromFileName0
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmfreememory0
//sys fwpmFreeMemory0(p unsafe.Pointer) = fwpuclnt.FwpmFreeMemory0
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmfilteradd0
//sys fwpmFilterAdd0(engineHandle uintptr, filter *wtFwpmFilter0, sd uintptr, id *uint64) (err error) [failretval!=0] = fwpuclnt.FwpmFilterAdd0
// https://docs.microsoft.com/en-us/windows/desktop/api/Fwpmu/nf-fwpmu-fwpmtransactionbegin0
//sys fwpmTransactionBegin0(engineHandle uintptr, flags uint32) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionBegin0
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmtransactioncommit0
//sys fwpmTransactionCommit0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionCommit0
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmtransactionabort0
//sys fwpmTransactionAbort0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionAbort0
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmprovideradd0
//sys fwpmProviderAdd0(engineHandle uintptr, provider *wtFwpmProvider0, sd uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmProviderAdd0

View File

@@ -0,0 +1,414 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*
* Adapted from wireguard-windows tunnel/firewall/types_windows.go.
*/
package dnsfw
import "golang.org/x/sys/windows"
const (
anysizeArray = 1 // ANYSIZE_ARRAY defined in winnt.h
wtFwpBitmapArray64_Size = 8
wtFwpByteArray16_Size = 16
wtFwpByteArray6_Size = 6
wtFwpmAction0_Size = 20
wtFwpmAction0_filterType_Offset = 4
wtFwpV4AddrAndMask_Size = 8
wtFwpV4AddrAndMask_mask_Offset = 4
wtFwpV6AddrAndMask_Size = 17
wtFwpV6AddrAndMask_prefixLength_Offset = 16
)
type wtFwpActionFlag uint32
const (
cFWP_ACTION_FLAG_TERMINATING wtFwpActionFlag = 0x00001000
cFWP_ACTION_FLAG_NON_TERMINATING wtFwpActionFlag = 0x00002000
cFWP_ACTION_FLAG_CALLOUT wtFwpActionFlag = 0x00004000
)
// FWP_ACTION_TYPE defined in fwptypes.h
type wtFwpActionType uint32
const (
cFWP_ACTION_BLOCK wtFwpActionType = wtFwpActionType(0x00000001 | cFWP_ACTION_FLAG_TERMINATING)
cFWP_ACTION_PERMIT wtFwpActionType = wtFwpActionType(0x00000002 | cFWP_ACTION_FLAG_TERMINATING)
cFWP_ACTION_CALLOUT_TERMINATING wtFwpActionType = wtFwpActionType(0x00000003 | cFWP_ACTION_FLAG_CALLOUT | cFWP_ACTION_FLAG_TERMINATING)
cFWP_ACTION_CALLOUT_INSPECTION wtFwpActionType = wtFwpActionType(0x00000004 | cFWP_ACTION_FLAG_CALLOUT | cFWP_ACTION_FLAG_NON_TERMINATING)
cFWP_ACTION_CALLOUT_UNKNOWN wtFwpActionType = wtFwpActionType(0x00000005 | cFWP_ACTION_FLAG_CALLOUT)
cFWP_ACTION_CONTINUE wtFwpActionType = wtFwpActionType(0x00000006 | cFWP_ACTION_FLAG_NON_TERMINATING)
cFWP_ACTION_NONE wtFwpActionType = 0x00000007
cFWP_ACTION_NONE_NO_MATCH wtFwpActionType = 0x00000008
cFWP_ACTION_BITMAP_INDEX_SET wtFwpActionType = 0x00000009
)
// FWP_BYTE_BLOB defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_blob_)
type wtFwpByteBlob struct {
size uint32
data *uint8
}
// FWP_MATCH_TYPE defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ne-fwptypes-fwp_match_type_)
type wtFwpMatchType uint32
const (
cFWP_MATCH_EQUAL wtFwpMatchType = 0
cFWP_MATCH_GREATER wtFwpMatchType = cFWP_MATCH_EQUAL + 1
cFWP_MATCH_LESS wtFwpMatchType = cFWP_MATCH_GREATER + 1
cFWP_MATCH_GREATER_OR_EQUAL wtFwpMatchType = cFWP_MATCH_LESS + 1
cFWP_MATCH_LESS_OR_EQUAL wtFwpMatchType = cFWP_MATCH_GREATER_OR_EQUAL + 1
cFWP_MATCH_RANGE wtFwpMatchType = cFWP_MATCH_LESS_OR_EQUAL + 1
cFWP_MATCH_FLAGS_ALL_SET wtFwpMatchType = cFWP_MATCH_RANGE + 1
cFWP_MATCH_FLAGS_ANY_SET wtFwpMatchType = cFWP_MATCH_FLAGS_ALL_SET + 1
cFWP_MATCH_FLAGS_NONE_SET wtFwpMatchType = cFWP_MATCH_FLAGS_ANY_SET + 1
cFWP_MATCH_EQUAL_CASE_INSENSITIVE wtFwpMatchType = cFWP_MATCH_FLAGS_NONE_SET + 1
cFWP_MATCH_NOT_EQUAL wtFwpMatchType = cFWP_MATCH_EQUAL_CASE_INSENSITIVE + 1
cFWP_MATCH_PREFIX wtFwpMatchType = cFWP_MATCH_NOT_EQUAL + 1
cFWP_MATCH_NOT_PREFIX wtFwpMatchType = cFWP_MATCH_PREFIX + 1
cFWP_MATCH_TYPE_MAX wtFwpMatchType = cFWP_MATCH_NOT_PREFIX + 1
)
// FWPM_ACTION0 defined in fwpmtypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_action0_)
type wtFwpmAction0 struct {
_type wtFwpActionType
filterType windows.GUID // Windows type: GUID
}
// Defined in fwpmu.h. 4cd62a49-59c3-4969-b7f3-bda5d32890a4
var cFWPM_CONDITION_IP_LOCAL_INTERFACE = windows.GUID{
Data1: 0x4cd62a49,
Data2: 0x59c3,
Data3: 0x4969,
Data4: [8]byte{0xb7, 0xf3, 0xbd, 0xa5, 0xd3, 0x28, 0x90, 0xa4},
}
// Defined in fwpmu.h. b235ae9a-1d64-49b8-a44c-5ff3d9095045
var cFWPM_CONDITION_IP_REMOTE_ADDRESS = windows.GUID{
Data1: 0xb235ae9a,
Data2: 0x1d64,
Data3: 0x49b8,
Data4: [8]byte{0xa4, 0x4c, 0x5f, 0xf3, 0xd9, 0x09, 0x50, 0x45},
}
// Defined in fwpmu.h. 3971ef2b-623e-4f9a-8cb1-6e79b806b9a7
var cFWPM_CONDITION_IP_PROTOCOL = windows.GUID{
Data1: 0x3971ef2b,
Data2: 0x623e,
Data3: 0x4f9a,
Data4: [8]byte{0x8c, 0xb1, 0x6e, 0x79, 0xb8, 0x06, 0xb9, 0xa7},
}
// Defined in fwpmu.h. 0c1ba1af-5765-453f-af22-a8f791ac775b
var cFWPM_CONDITION_IP_LOCAL_PORT = windows.GUID{
Data1: 0x0c1ba1af,
Data2: 0x5765,
Data3: 0x453f,
Data4: [8]byte{0xaf, 0x22, 0xa8, 0xf7, 0x91, 0xac, 0x77, 0x5b},
}
// Defined in fwpmu.h. c35a604d-d22b-4e1a-91b4-68f674ee674b
var cFWPM_CONDITION_IP_REMOTE_PORT = windows.GUID{
Data1: 0xc35a604d,
Data2: 0xd22b,
Data3: 0x4e1a,
Data4: [8]byte{0x91, 0xb4, 0x68, 0xf6, 0x74, 0xee, 0x67, 0x4b},
}
// Defined in fwpmu.h. d78e1e87-8644-4ea5-9437-d809ecefc971
var cFWPM_CONDITION_ALE_APP_ID = windows.GUID{
Data1: 0xd78e1e87,
Data2: 0x8644,
Data3: 0x4ea5,
Data4: [8]byte{0x94, 0x37, 0xd8, 0x09, 0xec, 0xef, 0xc9, 0x71},
}
// af043a0a-b34d-4f86-979c-c90371af6e66
var cFWPM_CONDITION_ALE_USER_ID = windows.GUID{
Data1: 0xaf043a0a,
Data2: 0xb34d,
Data3: 0x4f86,
Data4: [8]byte{0x97, 0x9c, 0xc9, 0x03, 0x71, 0xaf, 0x6e, 0x66},
}
// d9ee00de-c1ef-4617-bfe3-ffd8f5a08957
var cFWPM_CONDITION_IP_LOCAL_ADDRESS = windows.GUID{
Data1: 0xd9ee00de,
Data2: 0xc1ef,
Data3: 0x4617,
Data4: [8]byte{0xbf, 0xe3, 0xff, 0xd8, 0xf5, 0xa0, 0x89, 0x57},
}
var (
cFWPM_CONDITION_ICMP_TYPE = cFWPM_CONDITION_IP_LOCAL_PORT
cFWPM_CONDITION_ICMP_CODE = cFWPM_CONDITION_IP_REMOTE_PORT
)
// 7bc43cbf-37ba-45f1-b74a-82ff518eeb10
var cFWPM_CONDITION_L2_FLAGS = windows.GUID{
Data1: 0x7bc43cbf,
Data2: 0x37ba,
Data3: 0x45f1,
Data4: [8]byte{0xb7, 0x4a, 0x82, 0xff, 0x51, 0x8e, 0xeb, 0x10},
}
type wtFwpmL2Flags uint32
const cFWP_CONDITION_L2_IS_VM2VM wtFwpmL2Flags = 0x00000010
var cFWPM_CONDITION_FLAGS = windows.GUID{
Data1: 0x632ce23b,
Data2: 0x5167,
Data3: 0x435c,
Data4: [8]byte{0x86, 0xd7, 0xe9, 0x03, 0x68, 0x4a, 0xa8, 0x0c},
}
type wtFwpmFlags uint32
const cFWP_CONDITION_FLAG_IS_LOOPBACK wtFwpmFlags = 0x00000001
// Defined in fwpmtypes.h
type wtFwpmFilterFlags uint32
const (
cFWPM_FILTER_FLAG_NONE wtFwpmFilterFlags = 0x00000000
cFWPM_FILTER_FLAG_PERSISTENT wtFwpmFilterFlags = 0x00000001
cFWPM_FILTER_FLAG_BOOTTIME wtFwpmFilterFlags = 0x00000002
cFWPM_FILTER_FLAG_HAS_PROVIDER_CONTEXT wtFwpmFilterFlags = 0x00000004
cFWPM_FILTER_FLAG_CLEAR_ACTION_RIGHT wtFwpmFilterFlags = 0x00000008
cFWPM_FILTER_FLAG_PERMIT_IF_CALLOUT_UNREGISTERED wtFwpmFilterFlags = 0x00000010
cFWPM_FILTER_FLAG_DISABLED wtFwpmFilterFlags = 0x00000020
cFWPM_FILTER_FLAG_INDEXED wtFwpmFilterFlags = 0x00000040
cFWPM_FILTER_FLAG_HAS_SECURITY_REALM_PROVIDER_CONTEXT wtFwpmFilterFlags = 0x00000080
cFWPM_FILTER_FLAG_SYSTEMOS_ONLY wtFwpmFilterFlags = 0x00000100
cFWPM_FILTER_FLAG_GAMEOS_ONLY wtFwpmFilterFlags = 0x00000200
cFWPM_FILTER_FLAG_SILENT_MODE wtFwpmFilterFlags = 0x00000400
cFWPM_FILTER_FLAG_IPSEC_NO_ACQUIRE_INITIATE wtFwpmFilterFlags = 0x00000800
)
// FWPM_LAYER_ALE_AUTH_CONNECT_V4 (c38d57d1-05a7-4c33-904f-7fbceee60e82) defined in fwpmu.h
var cFWPM_LAYER_ALE_AUTH_CONNECT_V4 = windows.GUID{
Data1: 0xc38d57d1,
Data2: 0x05a7,
Data3: 0x4c33,
Data4: [8]byte{0x90, 0x4f, 0x7f, 0xbc, 0xee, 0xe6, 0x0e, 0x82},
}
// e1cd9fe7-f4b5-4273-96c0-592e487b8650
var cFWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4 = windows.GUID{
Data1: 0xe1cd9fe7,
Data2: 0xf4b5,
Data3: 0x4273,
Data4: [8]byte{0x96, 0xc0, 0x59, 0x2e, 0x48, 0x7b, 0x86, 0x50},
}
// FWPM_LAYER_ALE_AUTH_CONNECT_V6 (4a72393b-319f-44bc-84c3-ba54dcb3b6b4) defined in fwpmu.h
var cFWPM_LAYER_ALE_AUTH_CONNECT_V6 = windows.GUID{
Data1: 0x4a72393b,
Data2: 0x319f,
Data3: 0x44bc,
Data4: [8]byte{0x84, 0xc3, 0xba, 0x54, 0xdc, 0xb3, 0xb6, 0xb4},
}
// a3b42c97-9f04-4672-b87e-cee9c483257f
var cFWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V6 = windows.GUID{
Data1: 0xa3b42c97,
Data2: 0x9f04,
Data3: 0x4672,
Data4: [8]byte{0xb8, 0x7e, 0xce, 0xe9, 0xc4, 0x83, 0x25, 0x7f},
}
// 94c44912-9d6f-4ebf-b995-05ab8a088d1b
var cFWPM_LAYER_OUTBOUND_MAC_FRAME_NATIVE = windows.GUID{
Data1: 0x94c44912,
Data2: 0x9d6f,
Data3: 0x4ebf,
Data4: [8]byte{0xb9, 0x95, 0x05, 0xab, 0x8a, 0x08, 0x8d, 0x1b},
}
// d4220bd3-62ce-4f08-ae88-b56e8526df50
var cFWPM_LAYER_INBOUND_MAC_FRAME_NATIVE = windows.GUID{
Data1: 0xd4220bd3,
Data2: 0x62ce,
Data3: 0x4f08,
Data4: [8]byte{0xae, 0x88, 0xb5, 0x6e, 0x85, 0x26, 0xdf, 0x50},
}
// FWP_BITMAP_ARRAY64 defined in fwtypes.h
type wtFwpBitmapArray64 struct {
bitmapArray64 [8]uint8 // Windows type: [8]UINT8
}
// FWP_BYTE_ARRAY6 defined in fwtypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_array6_)
type wtFwpByteArray6 struct {
byteArray6 [6]uint8 // Windows type: [6]UINT8
}
// FWP_BYTE_ARRAY16 defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_array16_)
type wtFwpByteArray16 struct {
byteArray16 [16]uint8 // Windows type [16]UINT8
}
// FWP_CONDITION_VALUE0 defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_condition_value0).
type wtFwpConditionValue0 wtFwpValue0
// FWP_DATA_TYPE defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ne-fwptypes-fwp_data_type_)
type wtFwpDataType uint
const (
cFWP_EMPTY wtFwpDataType = 0
cFWP_UINT8 wtFwpDataType = cFWP_EMPTY + 1
cFWP_UINT16 wtFwpDataType = cFWP_UINT8 + 1
cFWP_UINT32 wtFwpDataType = cFWP_UINT16 + 1
cFWP_UINT64 wtFwpDataType = cFWP_UINT32 + 1
cFWP_INT8 wtFwpDataType = cFWP_UINT64 + 1
cFWP_INT16 wtFwpDataType = cFWP_INT8 + 1
cFWP_INT32 wtFwpDataType = cFWP_INT16 + 1
cFWP_INT64 wtFwpDataType = cFWP_INT32 + 1
cFWP_FLOAT wtFwpDataType = cFWP_INT64 + 1
cFWP_DOUBLE wtFwpDataType = cFWP_FLOAT + 1
cFWP_BYTE_ARRAY16_TYPE wtFwpDataType = cFWP_DOUBLE + 1
cFWP_BYTE_BLOB_TYPE wtFwpDataType = cFWP_BYTE_ARRAY16_TYPE + 1
cFWP_SID wtFwpDataType = cFWP_BYTE_BLOB_TYPE + 1
cFWP_SECURITY_DESCRIPTOR_TYPE wtFwpDataType = cFWP_SID + 1
cFWP_TOKEN_INFORMATION_TYPE wtFwpDataType = cFWP_SECURITY_DESCRIPTOR_TYPE + 1
cFWP_TOKEN_ACCESS_INFORMATION_TYPE wtFwpDataType = cFWP_TOKEN_INFORMATION_TYPE + 1
cFWP_UNICODE_STRING_TYPE wtFwpDataType = cFWP_TOKEN_ACCESS_INFORMATION_TYPE + 1
cFWP_BYTE_ARRAY6_TYPE wtFwpDataType = cFWP_UNICODE_STRING_TYPE + 1
cFWP_BITMAP_INDEX_TYPE wtFwpDataType = cFWP_BYTE_ARRAY6_TYPE + 1
cFWP_BITMAP_ARRAY64_TYPE wtFwpDataType = cFWP_BITMAP_INDEX_TYPE + 1
cFWP_SINGLE_DATA_TYPE_MAX wtFwpDataType = 0xff
cFWP_V4_ADDR_MASK wtFwpDataType = cFWP_SINGLE_DATA_TYPE_MAX + 1
cFWP_V6_ADDR_MASK wtFwpDataType = cFWP_V4_ADDR_MASK + 1
cFWP_RANGE_TYPE wtFwpDataType = cFWP_V6_ADDR_MASK + 1
cFWP_DATA_TYPE_MAX wtFwpDataType = cFWP_RANGE_TYPE + 1
)
// FWP_V4_ADDR_AND_MASK defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_v4_addr_and_mask).
type wtFwpV4AddrAndMask struct {
addr uint32
mask uint32
}
// FWP_V6_ADDR_AND_MASK defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_v6_addr_and_mask).
type wtFwpV6AddrAndMask struct {
addr [16]uint8
prefixLength uint8
}
// FWP_VALUE0 defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_value0_)
type wtFwpValue0 struct {
_type wtFwpDataType
value uintptr
}
// FWPM_DISPLAY_DATA0 defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwpm_display_data0).
type wtFwpmDisplayData0 struct {
name *uint16 // Windows type: *wchar_t
description *uint16 // Windows type: *wchar_t
}
// FWPM_FILTER_CONDITION0 defined in fwpmtypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_filter_condition0).
type wtFwpmFilterCondition0 struct {
fieldKey windows.GUID // Windows type: GUID
matchType wtFwpMatchType
conditionValue wtFwpConditionValue0
}
// FWPM_PROVIDER0 defined in fwpmtypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_provider0_)
type wtFwpProvider0 struct {
providerKey windows.GUID // Windows type: GUID
displayData wtFwpmDisplayData0
flags uint32
providerData wtFwpByteBlob
serviceName *uint16 // Windows type: *wchar_t
}
type wtFwpmSessionFlagsValue uint32
const (
cFWPM_SESSION_FLAG_DYNAMIC wtFwpmSessionFlagsValue = 0x00000001 // FWPM_SESSION_FLAG_DYNAMIC defined in fwpmtypes.h
)
// FWPM_SESSION0 defined in fwpmtypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_session0).
type wtFwpmSession0 struct {
sessionKey windows.GUID // Windows type: GUID
displayData wtFwpmDisplayData0
flags wtFwpmSessionFlagsValue // Windows type UINT32
txnWaitTimeoutInMSec uint32
processId uint32 // Windows type: DWORD
sid *windows.SID
username *uint16 // Windows type: *wchar_t
kernelMode uint8 // Windows type: BOOL
}
type wtFwpmSublayerFlags uint32
const (
cFWPM_SUBLAYER_FLAG_PERSISTENT wtFwpmSublayerFlags = 0x00000001 // FWPM_SUBLAYER_FLAG_PERSISTENT defined in fwpmtypes.h
)
// FWPM_SUBLAYER0 defined in fwpmtypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_sublayer0_)
type wtFwpmSublayer0 struct {
subLayerKey windows.GUID // Windows type: GUID
displayData wtFwpmDisplayData0
flags wtFwpmSublayerFlags
providerKey *windows.GUID // Windows type: *GUID
providerData wtFwpByteBlob
weight uint16
}
// Defined in rpcdce.h
type wtRpcCAuthN uint32
const (
cRPC_C_AUTHN_NONE wtRpcCAuthN = 0
cRPC_C_AUTHN_WINNT wtRpcCAuthN = 10
cRPC_C_AUTHN_DEFAULT wtRpcCAuthN = 0xFFFFFFFF
)
// FWPM_PROVIDER0 defined in fwpmtypes.h
// (https://docs.microsoft.com/sv-se/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_provider0).
type wtFwpmProvider0 struct {
providerKey windows.GUID
displayData wtFwpmDisplayData0
flags uint32
providerData wtFwpByteBlob
serviceName *uint16
}
type wtIPProto uint32
const (
cIPPROTO_ICMP wtIPProto = 1
cIPPROTO_ICMPV6 wtIPProto = 58
cIPPROTO_TCP wtIPProto = 6
cIPPROTO_UDP wtIPProto = 17
)
const (
cFWP_ACTRL_MATCH_FILTER = 1
)

Some files were not shown because too many files have changed in this diff Show More