mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
The cleanup path tried to delete the jump rule to NETBIRD-NAT-OUTPUT unconditionally, producing a noisy debug log when the chain was never created (common for the v6 router when no OutputDNAT rules exist). Check ChainExists first, consistent with the chain deletion loop below.
1148 lines
32 KiB
Go
1148 lines
32 KiB
Go
//go:build !android
|
|
|
|
package iptables
|
|
|
|
import (
|
|
"fmt"
|
|
"net/netip"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/coreos/go-iptables/iptables"
|
|
"github.com/hashicorp/go-multierror"
|
|
ipset "github.com/lrh3321/ipset-go"
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
|
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
|
nbnet "github.com/netbirdio/netbird/client/net"
|
|
)
|
|
|
|
// iptables tables the router writes to.
const (
	tableFilter = "filter"
	tableNat    = "nat"
	tableMangle = "mangle"
)

// Built-in chains the router hooks into, followed by the custom chains it owns.
const (
	chainPOSTROUTING = "POSTROUTING"
	chainPREROUTING  = "PREROUTING"
	chainFORWARD     = "FORWARD"

	chainRTNAT      = "NETBIRD-RT-NAT"
	chainRTFWDIN    = "NETBIRD-RT-FWD-IN"
	chainRTFWDOUT   = "NETBIRD-RT-FWD-OUT"
	chainRTPRE      = "NETBIRD-RT-PRE"
	chainRTRDR      = "NETBIRD-RT-RDR"
	chainNATOutput  = "NETBIRD-NAT-OUTPUT"
	chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
)

// Terminal targets for forwarding and NAT rules.
const (
	routingFinalForwardJump = "ACCEPT"
	routingFinalNatJump     = "MASQUERADE"
)

// Keys under which static jump and mark rules are tracked in router.rules.
const (
	jumpManglePre  = "jump-mangle-pre"
	jumpNatPre     = "jump-nat-pre"
	jumpNatPost    = "jump-nat-post"
	jumpNatOutput  = "jump-nat-output"
	jumpMSSClamp   = "jump-mss-clamp"
	markManglePre  = "mark-mangle-pre"
	markManglePost = "mark-mangle-post"
)

// matchSet is the iptables set-match option that precedes an ipset name.
const matchSet = "--match-set"

// Suffixes appended to a forward rule ID to key its DNAT, SNAT and filter parts.
const (
	dnatSuffix = "_dnat"
	snatSuffix = "_snat"
	fwdSuffix  = "_fwd"
)

// Minimal header overhead subtracted from the MTU when computing the clamped MSS.
const (
	// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size.
	ipv4TCPHeaderSize = 40
	// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size.
	ipv6TCPHeaderSize = 60
)
|
|
|
|
type ruleInfo struct {
|
|
chain string
|
|
table string
|
|
rule []string
|
|
}
|
|
|
|
type routeFilteringRuleParams struct {
|
|
Source firewall.Network
|
|
Destination firewall.Network
|
|
Proto firewall.Protocol
|
|
SPort *firewall.Port
|
|
DPort *firewall.Port
|
|
Direction firewall.RuleDirection
|
|
Action firewall.Action
|
|
}
|
|
|
|
type routeRules map[string][]string
|
|
|
|
// the ipset library currently does not support comments, so we use the name only (string)
|
|
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
|
|
|
|
type router struct {
|
|
iptablesClient *iptables.IPTables
|
|
rules routeRules
|
|
ipsetCounter *ipsetCounter
|
|
wgIface iFaceMapper
|
|
legacyManagement bool
|
|
mtu uint16
|
|
v6 bool
|
|
|
|
stateManager *statemanager.Manager
|
|
ipFwdState *ipfwdstate.IPForwardingState
|
|
}
|
|
|
|
// newRouter returns a router that uses iptablesClient for rule management and
// wgIface as the WireGuard interface. The address family is derived from the
// client's protocol. ipsets are managed through a reference counter whose
// callbacks are createIpSet and deleteIpSet. The error is currently always nil.
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*router, error) {
	isV6 := iptablesClient.Proto() == iptables.ProtocolIPv6

	r := &router{
		iptablesClient: iptablesClient,
		rules:          make(map[string][]string),
		wgIface:        wgIface,
		mtu:            mtu,
		v6:             isV6,
		ipFwdState:     ipfwdstate.NewIPForwardingState(),
	}

	onFirstUse := func(name string, sources []netip.Prefix) (struct{}, error) {
		return struct{}{}, r.createIpSet(name, sources)
	}
	onLastRelease := func(name string, _ struct{}) error {
		return r.deleteIpSet(name)
	}
	r.ipsetCounter = refcounter.New(onFirstUse, onLastRelease)

	return r, nil
}
|
|
|
|
// init attaches the state manager and prepares the router's iptables layout.
// It removes leftovers of an earlier run, creates the custom chains with their
// static rules, and installs the data-plane connmark rules.
//
// Only a failure to create the chains aborts initialization. The cleanup and
// mark steps are logged and skipped on error. The resulting rule set is
// persisted at the end.
func (r *router) init(stateManager *statemanager.Manager) error {
	r.stateManager = stateManager

	cleanupErr := r.cleanUpDefaultForwardRules()
	if cleanupErr != nil {
		log.Errorf("failed to clean up rules from FORWARD chain: %s", cleanupErr)
	}

	if err := r.createContainers(); err != nil {
		return fmt.Errorf("create containers: %w", err)
	}

	markErr := r.setupDataPlaneMark()
	if markErr != nil {
		log.Errorf("failed to set up data plane mark: %v", markErr)
	}

	r.updateState()
	return nil
}
|
|
|
|
// AddRouteFiltering installs a route ACL rule in NETBIRD-RT-FWD-IN. The rule
// matches traffic from sources to destination with the given protocol and
// ports, and applies action. It returns the rule key identifying the rule.
//
// If a rule with the same key is already tracked, its key is returned and
// nothing changes. More than one source is matched through an ipset; a single
// source is matched as a plain prefix. DROP rules are inserted at position 2,
// right after the established-connection rule, so they are evaluated before
// ACCEPT rules, which are appended.
//
// If the rule cannot be added to iptables, the ipset references taken while
// building it are released again, so a failed call does not keep sets alive.
//
// The id parameter is currently unused.
func (r *router) AddRouteFiltering(
	id []byte,
	sources []netip.Prefix,
	destination firewall.Network,
	proto firewall.Protocol,
	sPort *firewall.Port,
	dPort *firewall.Port,
	action firewall.Action,
) (firewall.Rule, error) {
	ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
	if _, ok := r.rules[string(ruleKey)]; ok {
		return ruleKey, nil
	}

	var source firewall.Network
	switch {
	case len(sources) > 1:
		source.Set = firewall.NewPrefixSet(sources)
	case len(sources) == 1:
		source.Prefix = sources[0]
	}

	params := routeFilteringRuleParams{
		Source:      source,
		Destination: destination,
		Proto:       proto,
		SPort:       sPort,
		DPort:       dPort,
		Action:      action,
	}

	rule, err := r.genRouteRuleSpec(params, sources)
	if err != nil {
		return nil, fmt.Errorf("generate route rule spec: %w", err)
	}

	// Insert DROP rules at the beginning, append ACCEPT rules at the end.
	if action == firewall.ActionDrop {
		// Position 2: right after the established rule.
		err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
	} else {
		err = r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...)
	}
	if err != nil {
		// genRouteRuleSpec already referenced any ipsets the rule matches on.
		// Release them so the failed rule does not keep the sets alive.
		if decErr := r.decrementSetCounter(rule); decErr != nil {
			return nil, fmt.Errorf("add route rule: %w (release ipset references: %v)", err, decErr)
		}
		return nil, fmt.Errorf("add route rule: %w", err)
	}

	r.rules[string(ruleKey)] = rule
	r.updateState()

	return ruleKey, nil
}
|
|
|
|
func (r *router) hasRule(id string) bool {
|
|
_, ok := r.rules[id]
|
|
return ok
|
|
}
|
|
|
|
// DeleteRouteRule removes a route ACL rule from NETBIRD-RT-FWD-IN and releases
// the ipset references it held. An untracked rule is only logged at debug
// level. The persisted state is refreshed unless an error is returned.
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
	key := rule.ID()

	spec, tracked := r.rules[key]
	if !tracked {
		log.Debugf("route rule %s not found", key)
		r.updateState()
		return nil
	}

	if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, spec...); err != nil {
		return fmt.Errorf("delete route rule: %v", err)
	}
	delete(r.rules, key)

	if err := r.decrementSetCounter(spec); err != nil {
		return fmt.Errorf("decrement ipset counter: %w", err)
	}

	r.updateState()
	return nil
}
|
|
|
|
// decrementSetCounter drops one reference for every ipset matched by rule.
// All sets are processed even if some decrements fail; the failures are
// returned together.
func (r *router) decrementSetCounter(rule []string) error {
	var merr *multierror.Error

	for _, name := range r.findSets(rule) {
		_, err := r.ipsetCounter.Decrement(name)
		if err == nil {
			continue
		}
		merr = multierror.Append(merr, fmt.Errorf("decrement counter: %w", err))
	}

	return nberrors.FormatErrorOrNil(merr)
}
|
|
|
|
func (r *router) findSets(rule []string) []string {
|
|
var sets []string
|
|
for i, arg := range rule {
|
|
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
|
|
sets = append(sets, rule[i+3])
|
|
}
|
|
}
|
|
return sets
|
|
}
|
|
|
|
// createIpSet creates the ipset setName and fills it with sources. It is the
// creation callback of the router's ipset reference counter (see newRouter).
//
// If adding an element fails, the half-populated set is destroyed before the
// error is returned. A later attempt to create the same set then does not run
// into a partial leftover. A failure to destroy it is reported alongside the
// original error.
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
	if err := r.createIPSet(setName); err != nil {
		return fmt.Errorf("create set %s: %w", setName, err)
	}

	for _, prefix := range sources {
		if err := r.addPrefixToIPSet(setName, prefix); err != nil {
			addErr := fmt.Errorf("add element to set %s: %w", setName, err)
			if destroyErr := r.destroyIPSet(setName); destroyErr != nil {
				return fmt.Errorf("%w (destroy partial set: %v)", addErr, destroyErr)
			}
			return addErr
		}
	}

	return nil
}
|
|
|
|
// deleteIpSet destroys the ipset setName. It is the removal callback of the
// router's ipset reference counter (see newRouter).
func (r *router) deleteIpSet(setName string) error {
	err := r.destroyIPSet(setName)
	if err != nil {
		return fmt.Errorf("destroy set %s: %w", setName, err)
	}

	log.Debugf("Deleted unused ipset %s", setName)
	return nil
}
|
|
|
|
// AddNatRule installs the masquerade marking rules for pair and its inverse.
//
// In legacy management mode, a permissive forwarding rule for the pair is
// added first. Marking rules are only installed when pair.Masquerade is set.
//
// The persisted state is refreshed on every return path. A legacy rule added
// before an early return, or before a later step fails, is therefore still
// recorded.
func (r *router) AddNatRule(pair firewall.RouterPair) error {
	// Persist however we leave: the legacy path may already have changed
	// r.rules even when masquerading is disabled or a later step fails.
	defer r.updateState()

	if r.legacyManagement {
		log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
		if err := r.addLegacyRouteRule(pair); err != nil {
			return fmt.Errorf("add legacy routing rule: %w", err)
		}
	}

	if !pair.Masquerade {
		return nil
	}

	if err := r.addNatRule(pair); err != nil {
		return fmt.Errorf("add nat rule: %w", err)
	}

	if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
		return fmt.Errorf("add inverse nat rule: %w", err)
	}

	return nil
}
|
|
|
|
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
|
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
|
if pair.Masquerade {
|
|
if err := r.removeNatRule(pair); err != nil {
|
|
return fmt.Errorf("remove nat rule: %w", err)
|
|
}
|
|
|
|
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
|
return fmt.Errorf("remove inverse nat rule: %w", err)
|
|
}
|
|
}
|
|
|
|
if err := r.removeLegacyRouteRule(pair); err != nil {
|
|
return fmt.Errorf("remove legacy routing rule: %w", err)
|
|
}
|
|
|
|
r.updateState()
|
|
|
|
return nil
|
|
}
|
|
|
|
// addLegacyRouteRule appends an ACCEPT rule to NETBIRD-RT-FWD-IN for traffic
// from pair.Source to pair.Destination. It is used for management servers that
// predate route ACLs. A previously tracked rule for the same pair is removed
// first, so repeated calls replace the rule instead of duplicating it.
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
	if err := r.removeLegacyRouteRule(pair); err != nil {
		return err
	}

	src, dst := pair.Source.String(), pair.Destination.String()
	spec := []string{"-s", src, "-d", dst, "-j", routingFinalForwardJump}

	if err := r.iptablesClient.Append(tableFilter, chainRTFWDIN, spec...); err != nil {
		return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
	}

	r.rules[firewall.GenKey(firewall.ForwardingFormat, pair)] = spec
	return nil
}
|
|
|
|
// removeLegacyRouteRule deletes the tracked legacy forwarding rule for pair,
// if any, and releases the ipset references it held. An untracked pair is a
// no-op.
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
	key := firewall.GenKey(firewall.ForwardingFormat, pair)

	spec, tracked := r.rules[key]
	if !tracked {
		return nil
	}

	if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, spec...); err != nil {
		return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
	}
	delete(r.rules, key)

	if err := r.decrementSetCounter(spec); err != nil {
		return fmt.Errorf("decrement ipset counter: %w", err)
	}

	return nil
}
|
|
|
|
// GetLegacyManagement returns the current legacy management mode
|
|
func (r *router) GetLegacyManagement() bool {
|
|
return r.legacyManagement
|
|
}
|
|
|
|
// SetLegacyManagement sets the route manager to use legacy management mode
|
|
func (r *router) SetLegacyManagement(isLegacy bool) {
|
|
r.legacyManagement = isLegacy
|
|
}
|
|
|
|
// RemoveAllLegacyRouteRules deletes every tracked legacy forwarding rule from
// NETBIRD-RT-FWD-IN. Rules that fail to delete stay tracked and their errors
// are returned together. The persisted state is always refreshed.
func (r *router) RemoveAllLegacyRouteRules() error {
	var merr *multierror.Error

	for key, spec := range r.rules {
		if !strings.HasPrefix(key, firewall.ForwardingFormatPrefix) {
			continue
		}

		err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, spec...)
		if err != nil {
			merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
			continue
		}
		// Deleting the current key while ranging over a map is safe in Go.
		delete(r.rules, key)
	}

	r.updateState()
	return nberrors.FormatErrorOrNil(merr)
}
|
|
|
|
func (r *router) Reset() error {
|
|
var merr *multierror.Error
|
|
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
|
merr = multierror.Append(merr, err)
|
|
}
|
|
|
|
if err := r.ipsetCounter.Flush(); err != nil {
|
|
merr = multierror.Append(merr, err)
|
|
}
|
|
|
|
if err := r.cleanupDataPlaneMark(); err != nil {
|
|
merr = multierror.Append(merr, err)
|
|
}
|
|
|
|
r.rules = make(map[string][]string)
|
|
r.updateState()
|
|
|
|
return nberrors.FormatErrorOrNil(merr)
|
|
}
|
|
|
|
// cleanUpDefaultForwardRules removes everything the router installs outside of
// individual rule operations:
//   - the tracked jump rules in the built-in chains,
//   - the lazily created OUTPUT jump to NETBIRD-NAT-OUTPUT,
//   - all of the router's custom chains (flushed, then deleted).
//
// Chains that do not exist are skipped. Once the OUTPUT jump is gone, its
// tracking entry is dropped as well, so a later ensureNATOutputChain call
// recreates the chain instead of assuming it is still installed.
func (r *router) cleanUpDefaultForwardRules() error {
	if err := r.cleanJumpRules(); err != nil {
		return fmt.Errorf("clean jump rules: %w", err)
	}

	log.Debug("flushing routing related tables")

	// Remove jump rules from built-in chains before deleting custom chains,
	// otherwise the chain deletion fails with "device or resource busy".
	natOutputExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput)
	if err != nil {
		return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
	}

	jumpRemoved := true
	if natOutputExists {
		// NETBIRD-NAT-OUTPUT is only created on demand, so only look for its
		// jump when the chain is actually present.
		jumpRule := []string{"-j", chainNATOutput}
		if err := r.iptablesClient.DeleteIfExists(tableNat, "OUTPUT", jumpRule...); err != nil {
			log.Debugf("clean OUTPUT jump rule: %v", err)
			jumpRemoved = false
		}
	}
	if jumpRemoved {
		// Keep the tracked rules in sync with iptables: a stale entry would make
		// ensureNATOutputChain skip recreating the chain deleted below.
		delete(r.rules, jumpNatOutput)
	}

	for _, chainInfo := range []struct {
		chain string
		table string
	}{
		{chainRTFWDIN, tableFilter},
		{chainRTFWDOUT, tableFilter},
		{chainRTPRE, tableMangle},
		{chainRTNAT, tableNat},
		{chainRTRDR, tableNat},
		{chainNATOutput, tableNat},
		{chainRTMSSCLAMP, tableMangle},
	} {
		ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
		if err != nil {
			return fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
		}
		if !ok {
			continue
		}
		if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
			return fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
		}
	}

	return nil
}
|
|
|
|
// createContainers creates the router's custom chains and fills them with the
// static rules every router needs:
//   - established-connection rules in both forwarding chains,
//   - the masquerade rules,
//   - the jumps from the built-in chains,
//   - MSS clamping, on a best-effort basis (failures are only logged).
//
// NETBIRD-NAT-OUTPUT is intentionally not created here; ensureNATOutputChain
// creates it on demand.
func (r *router) createContainers() error {
	chains := []struct {
		table string
		name  string
	}{
		{tableFilter, chainRTFWDIN},
		{tableFilter, chainRTFWDOUT},
		{tableMangle, chainRTPRE},
		{tableNat, chainRTNAT},
		{tableNat, chainRTRDR},
		{tableMangle, chainRTMSSCLAMP},
	}

	for _, c := range chains {
		// Fallback: clear chains that survived an unclean shutdown. An error
		// from the existence check is deliberately ignored; the chain is then
		// simply not cleared first.
		if exists, _ := r.iptablesClient.ChainExists(c.table, c.name); exists {
			if err := r.iptablesClient.ClearAndDeleteChain(c.table, c.name); err != nil {
				log.Warnf("clear stale chain %s in %s: %v", c.name, c.table, err)
			}
		}
		if err := r.iptablesClient.NewChain(c.table, c.name); err != nil {
			return fmt.Errorf("create chain %s in table %s: %w", c.name, c.table, err)
		}
	}

	mandatory := []struct {
		label string
		run   func() error
	}{
		{"insert established rule", func() error { return r.insertEstablishedRule(chainRTFWDIN) }},
		{"insert established rule", func() error { return r.insertEstablishedRule(chainRTFWDOUT) }},
		{"add static nat rules", r.addPostroutingRules},
		{"add jump rules", r.addJumpRules},
	}
	for _, step := range mandatory {
		if err := step.run(); err != nil {
			return fmt.Errorf("%s: %w", step.label, err)
		}
	}

	if err := r.addMSSClampingRules(); err != nil {
		log.Errorf("failed to add MSS clamping rules: %s", err)
	}

	return nil
}
|
|
|
|
// setupDataPlaneMark installs the mangle rules that connmark new connections
// with the data-plane marks: DataPlaneMarkIn for connections arriving on the
// WireGuard interface (PREROUTING) and DataPlaneMarkOut for connections leaving
// through it (POSTROUTING). Both rules are attempted. Rules that were added are
// tracked, and failures are returned together.
func (r *router) setupDataPlaneMark() error {
	iface := r.wgIface.Name()

	marks := []struct {
		key   string
		chain string
		hook  string
		spec  []string
	}{
		{
			key:   markManglePre,
			chain: chainPREROUTING,
			hook:  "prerouting",
			spec: []string{
				"-i", iface,
				"-m", "conntrack", "--ctstate", "NEW",
				"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkIn),
			},
		},
		{
			key:   markManglePost,
			chain: chainPOSTROUTING,
			hook:  "postrouting",
			spec: []string{
				"-o", iface,
				"-m", "conntrack", "--ctstate", "NEW",
				"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkOut),
			},
		},
	}

	var merr *multierror.Error
	for _, m := range marks {
		if err := r.iptablesClient.AppendUnique(tableMangle, m.chain, m.spec...); err != nil {
			merr = multierror.Append(merr, fmt.Errorf("add mangle %s rule: %w", m.hook, err))
			continue
		}
		r.rules[m.key] = m.spec
	}

	return nberrors.FormatErrorOrNil(merr)
}
|
|
|
|
// cleanupDataPlaneMark removes the tracked data-plane connmark rules from the
// mangle PREROUTING and POSTROUTING chains. A rule is untracked only after it
// has been removed. Failures are collected and returned together.
func (r *router) cleanupDataPlaneMark() error {
	var merr *multierror.Error

	for _, m := range []struct {
		key   string
		chain string
		hook  string
	}{
		{markManglePre, chainPREROUTING, "prerouting"},
		{markManglePost, chainPOSTROUTING, "postrouting"},
	} {
		spec, tracked := r.rules[m.key]
		if !tracked {
			continue
		}
		if err := r.iptablesClient.DeleteIfExists(tableMangle, m.chain, spec...); err != nil {
			merr = multierror.Append(merr, fmt.Errorf("remove mangle %s rule: %w", m.hook, err))
			continue
		}
		delete(r.rules, m.key)
	}

	return nberrors.FormatErrorOrNil(merr)
}
|
|
|
|
// addPostroutingRules appends the two static MASQUERADE rules to NETBIRD-RT-NAT:
//   - outbound: packets carrying PreroutingFwmarkMasquerade that do not leave via lo,
//   - return: packets carrying PreroutingFwmarkMasqueradeReturn that leave via
//     the WireGuard interface.
//
// The marks are set by the rules addNatRule installs. Processing stops at the
// first failure.
func (r *router) addPostroutingRules() error {
	static := []struct {
		key  string
		kind string
		spec []string
	}{
		{
			key:  "static-nat-outbound",
			kind: "outbound",
			spec: []string{
				"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
				"!", "-o", "lo",
				"-j", routingFinalNatJump,
			},
		},
		{
			key:  "static-nat-return",
			kind: "return",
			spec: []string{
				"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
				"-o", r.wgIface.Name(),
				"-j", routingFinalNatJump,
			},
		},
	}

	for _, s := range static {
		if err := r.iptablesClient.Append(tableNat, chainRTNAT, s.spec...); err != nil {
			return fmt.Errorf("add %s masquerade rule: %v", s.kind, err)
		}
		r.rules[s.key] = s.spec
	}

	return nil
}
|
|
|
|
// addMSSClampingRules clamps the TCP MSS of SYN packets leaving through the
// WireGuard interface, so forwarded TCP connections do not need fragmentation.
// The MSS is the router MTU minus the minimal IP+TCP header size for the
// address family. The method inserts a jump from the mangle FORWARD chain
// (position 1) to NETBIRD-RT-MSSCLAMP and appends the TCPMSS rule there.
//
// If the MTU is not larger than the header overhead, an error is returned
// before anything is installed. Otherwise the uint16 subtraction would wrap
// around and produce a bogus MSS.
func (r *router) addMSSClampingRules() error {
	overhead := uint16(ipv4TCPHeaderSize)
	if r.v6 {
		overhead = ipv6TCPHeaderSize
	}
	if r.mtu <= overhead {
		return fmt.Errorf("mtu %d too small for MSS clamping (header overhead %d)", r.mtu, overhead)
	}
	mss := r.mtu - overhead

	// Add jump rule from FORWARD chain in mangle table to our custom chain.
	jumpRule := []string{
		"-j", chainRTMSSCLAMP,
	}
	if err := r.iptablesClient.Insert(tableMangle, chainFORWARD, 1, jumpRule...); err != nil {
		return fmt.Errorf("add jump to MSS clamp chain: %w", err)
	}
	r.rules[jumpMSSClamp] = jumpRule

	ruleOut := []string{
		"-o", r.wgIface.Name(),
		"-p", "tcp",
		"--tcp-flags", "SYN,RST", "SYN",
		"-j", "TCPMSS",
		"--set-mss", strconv.Itoa(int(mss)),
	}
	if err := r.iptablesClient.Append(tableMangle, chainRTMSSCLAMP, ruleOut...); err != nil {
		return fmt.Errorf("add outbound MSS clamp rule: %w", err)
	}
	r.rules["mss-clamp-out"] = ruleOut

	return nil
}
|
|
|
|
// insertEstablishedRule puts the conntrack "established" rule at the top of
// the given filter chain and tracks it under "established-<chain>".
func (r *router) insertEstablishedRule(chain string) error {
	spec := getConntrackEstablished()

	if err := r.iptablesClient.Insert(tableFilter, chain, 1, spec...); err != nil {
		return fmt.Errorf("failed to insert established rule: %v", err)
	}

	r.rules["established-"+chain] = spec
	return nil
}
|
|
|
|
// addJumpRules inserts, at position 1, the jumps from the built-in chains into
// the router's chains, in this order:
//   - nat POSTROUTING -> NETBIRD-RT-NAT,
//   - mangle PREROUTING -> NETBIRD-RT-PRE,
//   - nat PREROUTING -> NETBIRD-RT-RDR.
//
// Each jump is tracked after it is added. Processing stops at the first failure.
func (r *router) addJumpRules() error {
	jumps := []struct {
		key    string
		table  string
		chain  string
		target string
		label  string
	}{
		{jumpNatPost, tableNat, chainPOSTROUTING, chainRTNAT, "nat postrouting"},
		{jumpManglePre, tableMangle, chainPREROUTING, chainRTPRE, "mangle prerouting"},
		{jumpNatPre, tableNat, chainPREROUTING, chainRTRDR, "nat prerouting"},
	}

	for _, j := range jumps {
		spec := []string{"-j", j.target}
		if err := r.iptablesClient.Insert(j.table, j.chain, 1, spec...); err != nil {
			return fmt.Errorf("add %s jump rule: %v", j.label, err)
		}
		r.rules[j.key] = spec
	}

	return nil
}
|
|
|
|
func (r *router) cleanJumpRules() error {
|
|
for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre, jumpMSSClamp} {
|
|
if rule, exists := r.rules[ruleKey]; exists {
|
|
var table, chain string
|
|
switch ruleKey {
|
|
case jumpNatPost:
|
|
table = tableNat
|
|
chain = chainPOSTROUTING
|
|
case jumpManglePre:
|
|
table = tableMangle
|
|
chain = chainPREROUTING
|
|
case jumpNatPre:
|
|
table = tableNat
|
|
chain = chainPREROUTING
|
|
case jumpMSSClamp:
|
|
table = tableMangle
|
|
chain = chainFORWARD
|
|
default:
|
|
return fmt.Errorf("unknown jump rule: %s", ruleKey)
|
|
}
|
|
|
|
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
|
|
return fmt.Errorf("delete rule from chain %s in table %s, err: %v", chain, table, err)
|
|
}
|
|
delete(r.rules, ruleKey)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
|
|
|
if rule, exists := r.rules[ruleKey]; exists {
|
|
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
|
|
return fmt.Errorf("error while removing existing marking rule for %s: %v", pair.Destination, err)
|
|
}
|
|
delete(r.rules, ruleKey)
|
|
}
|
|
|
|
markValue := nbnet.PreroutingFwmarkMasquerade
|
|
if pair.Inverse {
|
|
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
|
|
}
|
|
|
|
rule := []string{"-i", r.wgIface.Name()}
|
|
if pair.Inverse {
|
|
rule = []string{"!", "-i", r.wgIface.Name()}
|
|
}
|
|
|
|
rule = append(rule,
|
|
"-m", "conntrack",
|
|
"--ctstate", "NEW",
|
|
)
|
|
sourceExp, err := r.applyNetwork("-s", pair.Source, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("apply network -s: %w", err)
|
|
}
|
|
destExp, err := r.applyNetwork("-d", pair.Destination, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("apply network -d: %w", err)
|
|
}
|
|
|
|
rule = append(rule, sourceExp...)
|
|
rule = append(rule, destExp...)
|
|
rule = append(rule,
|
|
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
|
|
)
|
|
|
|
// Ensure nat rules come first, so the mark can be overwritten.
|
|
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
|
|
if err := r.iptablesClient.Insert(tableMangle, chainRTPRE, 1, rule...); err != nil {
|
|
// TODO: rollback ipset counter
|
|
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
|
|
}
|
|
|
|
r.rules[ruleKey] = rule
|
|
|
|
r.updateState()
|
|
return nil
|
|
}
|
|
|
|
// removeNatRule deletes the tracked NETBIRD-RT-PRE marking rule for pair and
// releases the ipset references it held. An untracked pair is only logged at
// debug level. The persisted state is refreshed unless an error is returned.
func (r *router) removeNatRule(pair firewall.RouterPair) error {
	key := firewall.GenKey(firewall.NatFormat, pair)

	spec, tracked := r.rules[key]
	if !tracked {
		log.Debugf("marking rule %s not found", key)
		r.updateState()
		return nil
	}

	if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, spec...); err != nil {
		return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
	}
	delete(r.rules, key)

	if err := r.decrementSetCounter(spec); err != nil {
		return fmt.Errorf("decrement ipset counter: %w", err)
	}

	r.updateState()
	return nil
}
|
|
|
|
func (r *router) updateState() {
|
|
if r.stateManager == nil {
|
|
return
|
|
}
|
|
|
|
var currentState *ShutdownState
|
|
if existing := r.stateManager.GetState(currentState); existing != nil {
|
|
if existingState, ok := existing.(*ShutdownState); ok {
|
|
currentState = existingState
|
|
}
|
|
}
|
|
if currentState == nil {
|
|
currentState = &ShutdownState{}
|
|
}
|
|
|
|
currentState.Lock()
|
|
defer currentState.Unlock()
|
|
|
|
if r.v6 {
|
|
currentState.RouteRules6 = r.rules
|
|
currentState.RouteIPsetCounter6 = r.ipsetCounter
|
|
} else {
|
|
currentState.RouteRules = r.rules
|
|
currentState.RouteIPsetCounter = r.ipsetCounter
|
|
}
|
|
|
|
if err := r.stateManager.UpdateState(currentState); err != nil {
|
|
log.Errorf("failed to update state: %v", err)
|
|
}
|
|
}
|
|
|
|
// AddDNATRule installs a port-forwarding rule made of three iptables rules:
//   - a DNAT rule in NETBIRD-RT-RDR (nat). It rewrites matching traffic that
//     arrives on any interface other than the WireGuard one to
//     rule.TranslatedAddress, optionally with a translated port or port range.
//   - a MASQUERADE rule in NETBIRD-RT-NAT (nat) for that traffic leaving through
//     the WireGuard interface.
//   - an ACCEPT rule in NETBIRD-RT-FWD-OUT (filter), for setups whose forward
//     policy is DROP.
//
// Every call requests IP forwarding. The request is kept when the rule is
// installed or already tracked; DeleteDNATRule releases it again. The request
// is released when the rule is invalid or cannot be installed, and any
// partially added iptables rules are then rolled back as well.
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
	if err := r.ipFwdState.RequestForwarding(); err != nil {
		return nil, err
	}

	// Release the forwarding request on every failure path below, so the
	// forwarding reference count does not leak.
	keepForwarding := false
	defer func() {
		if keepForwarding {
			return
		}
		if err := r.ipFwdState.ReleaseForwarding(); err != nil {
			log.Errorf("release forwarding: %v", err)
		}
	}()

	ruleKey := rule.ID()
	if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
		// Kept as before: DeleteDNATRule releases once per call.
		keepForwarding = true
		return rule, nil
	}

	toDestination := rule.TranslatedAddress.String()
	switch {
	case len(rule.TranslatedPort.Values) == 0:
		// no translated port, use original port
	case len(rule.TranslatedPort.Values) == 1:
		toDestination += fmt.Sprintf(":%d", rule.TranslatedPort.Values[0])
	case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
		// The "/originalport" suffix comes from the destination port; without
		// one the rule cannot be built (and indexing below would panic).
		if len(rule.DestinationPort.Values) == 0 {
			return nil, fmt.Errorf("invalid translated port: %v: port range requires a destination port", rule.TranslatedPort)
		}
		// need the "/originalport" suffix to avoid dnat port randomization
		toDestination += fmt.Sprintf(":%d-%d/%d", rule.TranslatedPort.Values[0], rule.TranslatedPort.Values[1], rule.DestinationPort.Values[0])
	default:
		return nil, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
	}

	proto := strings.ToLower(string(rule.Protocol))

	rules := make(map[string]ruleInfo, 3)

	// DNAT rule
	dnatRule := []string{
		"!", "-i", r.wgIface.Name(),
		"-p", proto,
		"-j", "DNAT",
		"--to-destination", toDestination,
	}
	dnatRule = append(dnatRule, applyPort("--dport", &rule.DestinationPort)...)
	rules[ruleKey+dnatSuffix] = ruleInfo{
		table: tableNat,
		chain: chainRTRDR,
		rule:  dnatRule,
	}

	// SNAT rule
	snatRule := []string{
		"-o", r.wgIface.Name(),
		"-p", proto,
		"-d", rule.TranslatedAddress.String(),
		"-j", "MASQUERADE",
	}
	snatRule = append(snatRule, applyPort("--dport", &rule.TranslatedPort)...)
	rules[ruleKey+snatSuffix] = ruleInfo{
		table: tableNat,
		chain: chainRTNAT,
		rule:  snatRule,
	}

	// Forward filtering rule, if fwd policy is DROP
	forwardRule := []string{
		"-o", r.wgIface.Name(),
		"-p", proto,
		"-d", rule.TranslatedAddress.String(),
		"-j", "ACCEPT",
	}
	forwardRule = append(forwardRule, applyPort("--dport", &rule.TranslatedPort)...)
	rules[ruleKey+fwdSuffix] = ruleInfo{
		table: tableFilter,
		chain: chainRTFWDOUT,
		rule:  forwardRule,
	}

	for key, info := range rules {
		if err := r.iptablesClient.Append(info.table, info.chain, info.rule...); err != nil {
			if rollbackErr := r.rollbackRules(rules); rollbackErr != nil {
				log.Errorf("rollback failed: %v", rollbackErr)
			}
			return nil, fmt.Errorf("add rule %s: %w", key, err)
		}
		r.rules[key] = info.rule
	}

	keepForwarding = true
	r.updateState()
	return rule, nil
}
|
|
|
|
// rollbackRules removes each of the given rules from iptables, ignoring rules
// that are not present. A rule that cannot be removed is kept in r.rules, so a
// later cleanup can still delete it. In that case the persisted state is
// refreshed and the failures are returned together.
func (r *router) rollbackRules(rules map[string]ruleInfo) error {
	var merr *multierror.Error

	for key, info := range rules {
		err := r.iptablesClient.DeleteIfExists(info.table, info.chain, info.rule...)
		if err == nil {
			continue
		}
		merr = multierror.Append(merr, fmt.Errorf("rollback rule %s: %w", key, err))
		// Remember the rule so the next cleanup can remove it.
		r.rules[key] = info.rule
	}

	if merr != nil {
		r.updateState()
	}

	return nberrors.FormatErrorOrNil(merr)
}
|
|
|
|
// DeleteDNATRule removes the DNAT, SNAT and forward parts of a port-forwarding
// rule and releases one IP forwarding request. Each tracked part is untracked
// even if deleting it from iptables fails; those failures are returned
// together after the state is persisted.
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
	if err := r.ipFwdState.ReleaseForwarding(); err != nil {
		log.Errorf("%v", err)
	}

	ruleKey := rule.ID()

	parts := []struct {
		suffix string
		table  string
		chain  string
		label  string
	}{
		{dnatSuffix, tableNat, chainRTRDR, "DNAT"},
		{snatSuffix, tableNat, chainRTNAT, "SNAT"},
		{fwdSuffix, tableFilter, chainRTFWDOUT, "forward"},
	}

	var merr *multierror.Error
	for _, part := range parts {
		key := ruleKey + part.suffix
		spec, tracked := r.rules[key]
		if !tracked {
			continue
		}
		if err := r.iptablesClient.Delete(part.table, part.chain, spec...); err != nil {
			merr = multierror.Append(merr, fmt.Errorf("delete %s rule: %w", part.label, err))
		}
		delete(r.rules, key)
	}

	r.updateState()
	return nberrors.FormatErrorOrNil(merr)
}
|
|
|
|
// genRouteRuleSpec builds the iptables arguments for a route ACL rule:
//   - the source and destination match (prefix or ipset),
//   - the protocol and ports, unless the protocol is "all",
//   - the jump target derived from the action.
//
// Building an ipset match takes a reference on the set (see applyNetwork). If
// the destination cannot be applied, the reference already taken for the
// source is released again. A release failure is reported alongside the error.
func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []netip.Prefix) ([]string, error) {
	sourceExp, err := r.applyNetwork("-s", params.Source, sources)
	if err != nil {
		return nil, fmt.Errorf("apply network -s: %w", err)
	}

	destExp, err := r.applyNetwork("-d", params.Destination, nil)
	if err != nil {
		// Do not leak the source ipset reference taken above.
		if decErr := r.decrementSetCounter(sourceExp); decErr != nil {
			return nil, fmt.Errorf("apply network -d: %w (release source ipset reference: %v)", err, decErr)
		}
		return nil, fmt.Errorf("apply network -d: %w", err)
	}

	var rule []string
	rule = append(rule, sourceExp...)
	rule = append(rule, destExp...)

	if params.Proto != firewall.ProtocolALL {
		rule = append(rule, "-p", strings.ToLower(protoForFamily(params.Proto, r.v6)))
		rule = append(rule, applyPort("--sport", params.SPort)...)
		rule = append(rule, applyPort("--dport", params.DPort)...)
	}

	rule = append(rule, "-j", actionToStr(params.Action))

	return rule, nil
}
|
|
|
|
// applyNetwork renders network as iptables match arguments for flag ("-s" or
// "-d"):
//   - a set becomes "-m set --match-set <name> src|dst", and one reference is
//     taken on the ipset (created from prefixes on first use),
//   - a prefix becomes "<flag> <prefix>",
//   - an empty network yields no arguments and no error.
func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
	switch {
	case network.IsSet():
		direction := "src"
		if flag == "-d" {
			direction = "dst"
		}

		name := r.ipsetName(network.Set.HashedName())
		if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
			return nil, fmt.Errorf("create or get ipset: %w", err)
		}
		return []string{"-m", "set", matchSet, name, direction}, nil

	case network.IsPrefix():
		return []string{flag, network.Prefix.String()}, nil

	default:
		// nolint:nilnil
		return nil, nil
	}
}
|
|
|
|
// UpdateSet adds prefixes to the ipset backing set. Every prefix is attempted;
// failures are returned together, and the update is logged at debug level
// only when all prefixes were added.
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
	name := r.ipsetName(set.HashedName())

	var merr *multierror.Error
	for i := range prefixes {
		err := r.addPrefixToIPSet(name, prefixes[i])
		if err == nil {
			continue
		}
		merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
	}

	if merr == nil {
		log.Debugf("updated set %s with prefixes %v", name, prefixes)
	}

	return nberrors.FormatErrorOrNil(merr)
}
|
|
|
|
// AddInboundDNAT appends a rule to NETBIRD-RT-RDR that redirects traffic
// arriving on the WireGuard interface for localAddr:originalPort to
// translatedPort. It only matches destinations that are local addresses. The
// call is idempotent per (address, protocol, original port, translated port).
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
	ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
	if _, exists := r.rules[ruleID]; exists {
		return nil
	}

	fromPort := strconv.Itoa(int(originalPort))
	toPort := strconv.Itoa(int(translatedPort))

	spec := []string{
		"-i", r.wgIface.Name(),
		"-p", strings.ToLower(protoForFamily(protocol, r.v6)),
		"--dport", fromPort,
		"-d", localAddr.String(),
		"-m", "addrtype", "--dst-type", "LOCAL",
		"-j", "DNAT",
		"--to-destination", ":" + toPort,
	}

	if err := r.iptablesClient.Append(tableNat, chainRTRDR, spec...); err != nil {
		return fmt.Errorf("add inbound DNAT rule: %w", err)
	}
	r.rules[ruleID] = spec

	r.updateState()
	return nil
}
|
|
|
|
// RemoveInboundDNAT deletes the inbound DNAT rule previously added by
// AddInboundDNAT for the same arguments. An untracked rule is a no-op. The
// persisted state is refreshed unless the deletion fails.
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
	ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)

	spec, tracked := r.rules[ruleID]
	if !tracked {
		r.updateState()
		return nil
	}

	if err := r.iptablesClient.Delete(tableNat, chainRTRDR, spec...); err != nil {
		return fmt.Errorf("delete inbound DNAT rule: %w", err)
	}
	delete(r.rules, ruleID)

	r.updateState()
	return nil
}
|
|
|
|
// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use.
|
|
func (r *router) ensureNATOutputChain() error {
|
|
if _, exists := r.rules[jumpNatOutput]; exists {
|
|
return nil
|
|
}
|
|
|
|
chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput)
|
|
if err != nil {
|
|
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
|
|
}
|
|
if !chainExists {
|
|
if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil {
|
|
return fmt.Errorf("create chain %s: %w", chainNATOutput, err)
|
|
}
|
|
}
|
|
|
|
jumpRule := []string{"-j", chainNATOutput}
|
|
if err := r.iptablesClient.Insert(tableNat, "OUTPUT", 1, jumpRule...); err != nil {
|
|
if !chainExists {
|
|
if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil {
|
|
log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr)
|
|
}
|
|
}
|
|
return fmt.Errorf("add OUTPUT jump rule: %w", err)
|
|
}
|
|
r.rules[jumpNatOutput] = jumpRule
|
|
|
|
r.updateState()
|
|
return nil
|
|
}
|
|
|
|
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
|
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
|
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
|
|
|
|
if _, exists := r.rules[ruleID]; exists {
|
|
return nil
|
|
}
|
|
|
|
if err := r.ensureNATOutputChain(); err != nil {
|
|
return err
|
|
}
|
|
|
|
dnatRule := []string{
|
|
"-p", strings.ToLower(string(protocol)),
|
|
"--dport", strconv.Itoa(int(originalPort)),
|
|
"-d", localAddr.String(),
|
|
"-j", "DNAT",
|
|
"--to-destination", ":" + strconv.Itoa(int(translatedPort)),
|
|
}
|
|
|
|
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
|
|
return fmt.Errorf("add output DNAT rule: %w", err)
|
|
}
|
|
r.rules[ruleID] = dnatRule
|
|
|
|
r.updateState()
|
|
return nil
|
|
}
|
|
|
|
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
|
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
|
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
|
|
|
|
if dnatRule, exists := r.rules[ruleID]; exists {
|
|
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
|
|
return fmt.Errorf("delete output DNAT rule: %w", err)
|
|
}
|
|
delete(r.rules, ruleID)
|
|
}
|
|
|
|
r.updateState()
|
|
return nil
|
|
}
|
|
|
|
func applyPort(flag string, port *firewall.Port) []string {
|
|
if port == nil {
|
|
return nil
|
|
}
|
|
|
|
if port.IsRange && len(port.Values) == 2 {
|
|
return []string{flag, fmt.Sprintf("%d:%d", port.Values[0], port.Values[1])}
|
|
}
|
|
|
|
if len(port.Values) > 1 {
|
|
portList := make([]string, len(port.Values))
|
|
for i, p := range port.Values {
|
|
portList[i] = strconv.Itoa(int(p))
|
|
}
|
|
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
|
|
}
|
|
|
|
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
|
}
|
|
|
|
// ipsetName returns the ipset name, suffixed with "-v6" for the v6 router
|
|
// to avoid collisions since ipsets are global in the kernel.
|
|
func (r *router) ipsetName(name string) string {
|
|
if r.v6 {
|
|
return name + "-v6"
|
|
}
|
|
return name
|
|
}
|
|
|
|
func (r *router) createIPSet(name string) error {
|
|
opts := ipset.CreateOptions{
|
|
Replace: true,
|
|
}
|
|
if r.v6 {
|
|
opts.Family = ipset.FamilyIPV6
|
|
}
|
|
|
|
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
|
return fmt.Errorf("create ipset %s: %w", name, err)
|
|
}
|
|
|
|
log.Debugf("created ipset %s with type hash:net", name)
|
|
return nil
|
|
}
|
|
|
|
func (r *router) addPrefixToIPSet(name string, prefix netip.Prefix) error {
|
|
addr := prefix.Addr()
|
|
ip := addr.AsSlice()
|
|
|
|
entry := &ipset.Entry{
|
|
IP: ip,
|
|
CIDR: uint8(prefix.Bits()),
|
|
Replace: true,
|
|
}
|
|
|
|
if err := ipset.Add(name, entry); err != nil {
|
|
return fmt.Errorf("add prefix to ipset %s: %w", name, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *router) destroyIPSet(name string) error {
|
|
return ipset.Destroy(name)
|
|
}
|