[client] Set up firewall rules for dns routes dynamically based on dns response (#3702)

This commit is contained in:
Viktor Liu
2025-04-24 17:37:28 +02:00
committed by GitHub
parent 85f92f8321
commit 4a9049566a
45 changed files with 1399 additions and 591 deletions

View File

@@ -10,7 +10,6 @@ import (
"strings"
"github.com/coreos/go-iptables/iptables"
"github.com/davecgh/go-spew/spew"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
@@ -44,9 +43,14 @@ const (
const refreshRulesMapError = "refresh rules map: %w"
var (
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
)
// setInput is the input value handed to the ipset reference counter when a
// set reference is acquired; it carries what is needed to build the nftables set.
type setInput struct {
// set is the firewall set; its hashed name keys the reference counter and
// its comment is attached to the created nftables set.
set firewall.Set
// prefixes are the prefixes the nftables set is populated with on creation
// (merged first, since overlapping prefixes are rejected).
prefixes []netip.Prefix
}
type router struct {
conn *nftables.Conn
workTable *nftables.Table
@@ -54,7 +58,7 @@ type router struct {
chains map[string]*nftables.Chain
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
rules map[string]*nftables.Rule
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState
@@ -163,7 +167,7 @@ func (r *router) removeNatPreroutingRules() error {
func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
return nil, fmt.Errorf("unable to list tables: %v", err)
}
for _, table := range tables {
@@ -316,7 +320,7 @@ func (r *router) setupDataPlaneMark() error {
func (r *router) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
@@ -331,23 +335,29 @@ func (r *router) AddRouteFiltering(
chain := r.chains[chainNameRoutingFw]
var exprs []expr.Any
var source firewall.Network
switch {
case len(sources) == 1 && sources[0].Bits() == 0:
// If it's 0.0.0.0/0, we don't need to add any source matching
case len(sources) == 1:
// If there's only one source, we can use it directly
exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...)
source.Prefix = sources[0]
default:
// If there are multiple sources, create or get an ipset
var err error
exprs, err = r.getIpSetExprs(sources, exprs)
if err != nil {
return nil, fmt.Errorf("get ipset expressions: %w", err)
}
// If there are multiple sources, use a set
source.Set = firewall.NewPrefixSet(sources)
}
// Handle destination
exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...)
sourceExp, err := r.applyNetwork(source, sources, true)
if err != nil {
return nil, fmt.Errorf("apply source: %w", err)
}
exprs = append(exprs, sourceExp...)
destExp, err := r.applyNetwork(destination, nil, false)
if err != nil {
return nil, fmt.Errorf("apply destination: %w", err)
}
exprs = append(exprs, destExp...)
// Handle protocol
if proto != firewall.ProtocolALL {
@@ -391,39 +401,27 @@ func (r *router) AddRouteFiltering(
rule = r.conn.AddRule(rule)
}
log.Tracef("Adding route rule %s", spew.Sdump(rule))
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf(flushError, err)
}
r.rules[string(ruleKey)] = rule
log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
log.Debugf("added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
return ruleKey, nil
}
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
setName := firewall.GenerateSetName(sources)
ref, err := r.ipsetCounter.Increment(setName, sources)
func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bool) ([]expr.Any, error) {
ref, err := r.ipsetCounter.Increment(set.HashedName(), setInput{
set: set,
prefixes: prefixes,
})
if err != nil {
return nil, fmt.Errorf("create or get ipset for sources: %w", err)
return nil, fmt.Errorf("create or get ipset: %w", err)
}
exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: ref.Out.Name,
SetID: ref.Out.ID,
},
)
return exprs, nil
return getIpSetExprs(ref, isSource)
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
@@ -442,42 +440,54 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
return fmt.Errorf("route rule %s has no handle", ruleKey)
}
setName := r.findSetNameInRule(nftRule)
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
return fmt.Errorf("delete: %w", err)
}
if setName != "" {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
return fmt.Errorf("decrement ipset reference: %w", err)
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
if err := r.decrementSetCounter(nftRule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
return nil
}
func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) {
func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, error) {
// overlapping prefixes will result in an error, so we need to merge them
sources = firewall.MergeIPRanges(sources)
prefixes := firewall.MergeIPRanges(input.prefixes)
set := &nftables.Set{
Name: setName,
Table: r.workTable,
nfset := &nftables.Set{
Name: setName,
Comment: input.set.Comment(),
Table: r.workTable,
// required for prefixes
Interval: true,
KeyType: nftables.TypeIPAddr,
}
elements := convertPrefixesToSet(prefixes)
if err := r.conn.AddSet(nfset, elements); err != nil {
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
return nfset, nil
}
func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
var elements []nftables.SetElement
for _, prefix := range sources {
for _, prefix := range prefixes {
// TODO: Implement IPv6 support
if prefix.Addr().Is6() {
log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue
}
@@ -493,18 +503,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
)
}
if err := r.conn.AddSet(set, elements); err != nil {
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
return set, nil
return elements
}
// calculateLastIP determines the last IP in a given prefix.
@@ -528,8 +527,8 @@ func uint32ToBytes(ip uint32) [4]byte {
return b
}
func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
r.conn.DelSet(set)
func (r *router) deleteIpSet(setName string, nfset *nftables.Set) error {
r.conn.DelSet(nfset)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
@@ -538,13 +537,27 @@ func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
return nil
}
func (r *router) findSetNameInRule(rule *nftables.Rule) string {
for _, e := range rule.Exprs {
if lookup, ok := e.(*expr.Lookup); ok {
return lookup.SetName
func (r *router) decrementSetCounter(rule *nftables.Rule) error {
sets := r.findSets(rule)
var merr *multierror.Error
for _, setName := range sets {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
merr = multierror.Append(merr, fmt.Errorf("decrement set counter: %w", err))
}
}
return ""
return nberrors.FormatErrorOrNil(merr)
}
// findSets returns the set names referenced by the rule's lookup expressions,
// in expression order. A set referenced more than once is listed once per
// reference. The result is nil when the rule has no lookup expressions.
func (r *router) findSets(rule *nftables.Rule) []string {
	var names []string
	for _, e := range rule.Exprs {
		lookup, isLookup := e.(*expr.Lookup)
		if !isLookup {
			continue
		}
		names = append(names, lookup.SetName)
	}
	return names
}
func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
@@ -586,7 +599,8 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err)
// TODO: rollback ipset counter
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
}
return nil
@@ -594,19 +608,22 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
// addNatRule inserts a nftables rule to the conn client flush queue
func (r *router) addNatRule(pair firewall.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
if err != nil {
return fmt.Errorf("apply source: %w", err)
}
destExp, err := r.applyNetwork(pair.Destination, nil, false)
if err != nil {
return fmt.Errorf("apply destination: %w", err)
}
op := expr.CmpOpEq
if pair.Inverse {
op = expr.CmpOpNeq
}
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
exprs := getCtNewExprs()
exprs = append(exprs,
// interface matching
exprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
@@ -616,7 +633,10 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
Register: 1,
Data: ifname(r.wgIface.Name()),
},
)
}
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
exprs = append(exprs, getCtNewExprs()...)
exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...)
@@ -729,8 +749,15 @@ func (r *router) addPostroutingRules() error {
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
if err != nil {
return fmt.Errorf("apply source: %w", err)
}
destExp, err := r.applyNetwork(pair.Destination, nil, false)
if err != nil {
return fmt.Errorf("apply destination: %w", err)
}
exprs := []expr.Any{
&expr.Counter{},
@@ -739,7 +766,8 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
},
}
expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic
exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...)
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
@@ -752,7 +780,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Exprs: expression,
Exprs: exprs,
UserData: []byte(ruleKey),
})
return nil
@@ -767,11 +795,13 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
} else {
log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
}
return nil
@@ -982,12 +1012,14 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return fmt.Errorf(refreshRulesMapError, err)
}
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove prerouting rule: %w", err)
}
if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove prerouting rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse prerouting rule: %w", err)
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse prerouting rule: %w", err)
}
}
if err := r.removeLegacyRouteRule(pair); err != nil {
@@ -995,10 +1027,10 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
// TODO: rollback set counter
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
}
log.Debugf("nftables: removed nat rules for %s", pair.Destination)
return nil
}
@@ -1006,16 +1038,19 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
err := r.conn.DelRule(rule)
if err != nil {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination)
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
} else {
log.Debugf("nftables: prerouting rule %s not found", ruleKey)
log.Debugf("prerouting rule %s not found", ruleKey)
}
return nil
@@ -1027,7 +1062,7 @@ func (r *router) refreshRulesMap() error {
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("nftables: unable to list rules: %v", err)
return fmt.Errorf(" unable to list rules: %v", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
@@ -1301,13 +1336,54 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
return nberrors.FormatErrorOrNil(merr)
}
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
var offset uint32
if source {
offset = 12 // src offset
} else {
offset = 16 // dst offset
// UpdateSet adds the given prefixes to the existing nftables set backing the
// firewall set and flushes the change.
//
// The set is looked up in the work table by its hashed name. An error is
// returned if the lookup, the element addition or the flush fails.
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
	setName := set.HashedName()

	nfset, err := r.conn.GetSetByName(r.workTable, setName)
	if err != nil {
		return fmt.Errorf("get set %s: %w", setName, err)
	}

	// overlapping prefixes will result in an error, so we need to merge them
	// before converting them to interval elements (same as on set creation)
	elements := convertPrefixesToSet(firewall.MergeIPRanges(prefixes))
	if err := r.conn.SetAddElements(nfset, elements); err != nil {
		return fmt.Errorf("add elements to set %s: %w", setName, err)
	}

	if err := r.conn.Flush(); err != nil {
		return fmt.Errorf(flushError, err)
	}

	log.Debugf("updated set %s with prefixes %v", setName, prefixes)

	return nil
}
// applyNetwork generates nftables expressions for networks (CIDR) or sets.
//
// If the network is a set, a reference to the set is acquired through getIpSet
// (setPrefixes is forwarded along with the set) and lookup expressions are
// returned. If it is a prefix, CIDR matching expressions are returned.
// isSource selects whether the source or the destination address is matched.
// When the network is neither a set nor a prefix, no expressions and no error
// are returned.
func (r *router) applyNetwork(
	network firewall.Network,
	setPrefixes []netip.Prefix,
	isSource bool,
) ([]expr.Any, error) {
	if network.IsSet() {
		exprs, err := r.getIpSet(network.Set, setPrefixes, isSource)
		if err != nil {
			// label the error with the side actually being processed
			side := "destination"
			if isSource {
				side = "source"
			}
			return nil, fmt.Errorf("%s: %w", side, err)
		}
		return exprs, nil
	}

	if network.IsPrefix() {
		return applyPrefix(network.Prefix, isSource), nil
	}

	return nil, nil
}
// applyPrefix generates nftables expressions for a CIDR prefix
func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
// dst offset
offset := uint32(16)
if isSource {
// src offset
offset = 12
}
ones := prefix.Bits()
@@ -1415,3 +1491,27 @@ func getCtNewExprs() []expr.Any {
},
}
}
// getIpSetExprs builds the expressions that match a packet address against the
// referenced set: 4 bytes are loaded from the network header into register 1
// (offset 12 for the source address, 16 for the destination address) and then
// looked up in the set. The returned error is always nil.
func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
	var addrOffset uint32
	switch {
	case isSource:
		// src offset
		addrOffset = 12
	default:
		// dst offset
		addrOffset = 16
	}

	loadAddr := &expr.Payload{
		DestRegister: 1,
		Base:         expr.PayloadBaseNetworkHeader,
		Offset:       addrOffset,
		Len:          4,
	}
	lookupInSet := &expr.Lookup{
		SourceRegister: 1,
		SetName:        ref.Out.Name,
		SetID:          ref.Out.ID,
	}

	return []expr.Any{loadAddr, lookupInSet}, nil
}