diff --git a/client/internal/routemanager/nftables_linux.go b/client/internal/routemanager/nftables_linux.go index dde8c143c..ca7d74f2a 100644 --- a/client/internal/routemanager/nftables_linux.go +++ b/client/internal/routemanager/nftables_linux.go @@ -19,6 +19,9 @@ const ( nftablesTable = "netbird-rt" nftablesRoutingForwardingChain = "netbird-rt-fwd" nftablesRoutingNatChain = "netbird-rt-nat" + + userDataAcceptForwardRuleSrc = "frwacceptsrc" + userDataAcceptForwardRuleDst = "frwacceptdst" ) // constants needed to create nftable rules @@ -71,25 +74,28 @@ var ( ) type nftablesManager struct { - ctx context.Context - stop context.CancelFunc - conn *nftables.Conn - tableIPv4 *nftables.Table - tableIPv6 *nftables.Table - chains map[string]map[string]*nftables.Chain - rules map[string]*nftables.Rule - mux sync.Mutex + ctx context.Context + stop context.CancelFunc + conn *nftables.Conn + tableIPv4 *nftables.Table + tableIPv6 *nftables.Table + chains map[string]map[string]*nftables.Chain + rules map[string]*nftables.Rule + filterTable *nftables.Table + defaultForwardRules []*nftables.Rule + mux sync.Mutex } func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) { ctx, cancel := context.WithCancel(parentCtx) mgr := &nftablesManager{ - ctx: ctx, - stop: cancel, - conn: &nftables.Conn{}, - chains: make(map[string]map[string]*nftables.Chain), - rules: make(map[string]*nftables.Rule), + ctx: ctx, + stop: cancel, + conn: &nftables.Conn{}, + chains: make(map[string]map[string]*nftables.Chain), + rules: make(map[string]*nftables.Rule), + defaultForwardRules: make([]*nftables.Rule, 2), } err := mgr.isSupported() @@ -97,6 +103,11 @@ func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) { return nil, err } + err = mgr.readFilterTable() + if err != nil { + return nil, err + } + return mgr, nil } @@ -109,6 +120,13 @@ func (n *nftablesManager) CleanRoutingRules() { n.conn.FlushTable(n.tableIPv6) n.conn.FlushTable(n.tableIPv4) } + + if n.defaultForwardRules[0] != nil { + err := n.eraseDefaultForwardRule() + if err != nil { + log.Errorf("failed to delete forward rule: %s", err) + } + } log.Debugf("flushing tables result in: %v error", n.conn.Flush()) } @@ -241,6 +259,112 @@ func (n *nftablesManager) refreshRulesMap() error { return nil } +func (n *nftablesManager) readFilterTable() error { + tables, err := n.conn.ListTables() + if err != nil { + return err + } + + for _, t := range tables { + if t.Name == "filter" { + n.filterTable = t + return nil + } + } + return nil +} + +func (n *nftablesManager) eraseDefaultForwardRule() error { + if n.defaultForwardRules[0] == nil { + return nil + } + + err := n.refreshDefaultForwardRule() + if err != nil { + return err + } + + for i, r := range n.defaultForwardRules { + err = n.conn.DelRule(r) + if err != nil { + log.Errorf("failed to delete forward rule (%d): %s", i, err) + } + n.defaultForwardRules[i] = nil + } + return nil +} + +func (n *nftablesManager) refreshDefaultForwardRule() error { + rules, err := n.conn.GetRules(n.defaultForwardRules[0].Table, n.defaultForwardRules[0].Chain) + if err != nil { + return fmt.Errorf("unable to list rules in forward chain: %s", err) + } + + found := false + for i, r := range n.defaultForwardRules { + for _, rule := range rules { + if string(rule.UserData) == string(r.UserData) { + n.defaultForwardRules[i] = rule + found = true + break + } + } + } + if !found { + return fmt.Errorf("unable to find forward accept rule") + } + + return nil +} + +func (n *nftablesManager) acceptForwardRule(sourceNetwork string) error { + src := generateCIDRMatcherExpressions("source", sourceNetwork) + dst := generateCIDRMatcherExpressions("destination", "0.0.0.0/0") + + var exprs []expr.Any + exprs = append(src, append(dst, &expr.Verdict{ + Kind: expr.VerdictAccept, + })...) + + r := &nftables.Rule{ + Table: n.filterTable, + Chain: &nftables.Chain{ + Name: "FORWARD", + Table: n.filterTable, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + }, + Exprs: exprs, + UserData: []byte(userDataAcceptForwardRuleSrc), + } + + n.defaultForwardRules[0] = n.conn.AddRule(r) + + src = generateCIDRMatcherExpressions("source", "0.0.0.0/0") + dst = generateCIDRMatcherExpressions("destination", sourceNetwork) + + exprs = append(src, append(dst, &expr.Verdict{ + Kind: expr.VerdictAccept, + })...) + + r = &nftables.Rule{ + Table: n.filterTable, + Chain: &nftables.Chain{ + Name: "FORWARD", + Table: n.filterTable, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + }, + Exprs: exprs, + UserData: []byte(userDataAcceptForwardRuleDst), + } + + n.defaultForwardRules[1] = n.conn.AddRule(r) + return nil +} + // checkOrCreateDefaultForwardingRules checks if the default forwarding rules are enabled func (n *nftablesManager) checkOrCreateDefaultForwardingRules() { _, foundIPv4 := n.rules[ipv4Forwarding] @@ -294,6 +418,14 @@ func (n *nftablesManager) InsertRoutingRules(pair routerPair) error { } } + if n.defaultForwardRules[0] == nil && n.filterTable != nil { + err = n.acceptForwardRule(pair.source) + if err != nil { + log.Errorf("unable to create default forward rule: %s", err) + } + log.Debugf("default accept forward rule added") + } + err = n.conn.Flush() if err != nil { return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err) @@ -374,6 +506,13 @@ func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error { return err } + if len(n.rules) == 2 && n.defaultForwardRules[0] != nil { + err := n.eraseDefaultForwardRule() + if err != nil { + log.Errorf("failed to delte default fwd rule: %s", err) + } + } + err = n.conn.Flush() if err != nil { return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)