[client] Set up firewall rules for dns routes dynamically based on dns response (#3702)

This commit is contained in:
Viktor Liu
2025-04-24 17:37:28 +02:00
committed by GitHub
parent 85f92f8321
commit 4a9049566a
45 changed files with 1399 additions and 591 deletions

View File

@@ -18,7 +18,7 @@ func (r RuleID) ID() string {
func GenerateRouteRuleKey(
sources []netip.Prefix,
destination netip.Prefix,
destination manager.Network,
proto manager.Protocol,
sPort *manager.Port,
dPort *manager.Port,

View File

@@ -18,6 +18,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
@@ -25,7 +26,7 @@ var ErrSourceRangesEmpty = errors.New("sources range is empty")
// Manager is a ACL rules manager
type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap)
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
}
type protoMatch struct {
@@ -53,7 +54,7 @@ func NewDefaultManager(fm firewall.Manager) *DefaultManager {
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
//
// If allowByDefault is true it appends allow ALL traffic rules to input and output chains.
func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) {
d.mutex.Lock()
defer d.mutex.Unlock()
@@ -82,7 +83,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
log.Errorf("failed to set legacy management flag: %v", err)
}
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules); err != nil {
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
log.Errorf("Failed to apply route ACLs: %v", err)
}
@@ -176,16 +177,16 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
d.peerRulesPairs = newRulePairs
}
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error {
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dynamicResolver bool) error {
newRouteRules := make(map[id.RuleID]struct{}, len(rules))
var merr *multierror.Error
// Apply new rules - firewall manager will return existing rule ID if already present
for _, rule := range rules {
id, err := d.applyRouteACL(rule)
id, err := d.applyRouteACL(rule, dynamicResolver)
if err != nil {
if errors.Is(err, ErrSourceRangesEmpty) {
log.Debugf("skipping empty rule with destination %s: %v", rule.Destination, err)
log.Debugf("skipping empty sources rule with destination %s: %v", rule.Destination, err)
} else {
merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err))
}
@@ -208,7 +209,7 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) err
return nberrors.FormatErrorOrNil(merr)
}
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) {
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (id.RuleID, error) {
if len(rule.SourceRanges) == 0 {
return "", ErrSourceRangesEmpty
}
@@ -222,15 +223,9 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul
sources = append(sources, source)
}
var destination netip.Prefix
if rule.IsDynamic {
destination = getDefault(sources[0])
} else {
var err error
destination, err = netip.ParsePrefix(rule.Destination)
if err != nil {
return "", fmt.Errorf("parse destination: %w", err)
}
destination, err := determineDestination(rule, dynamicResolver, sources)
if err != nil {
return "", fmt.Errorf("determine destination: %w", err)
}
protocol, err := convertToFirewallProtocol(rule.Protocol)
@@ -580,6 +575,33 @@ func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
return nil
}
func determineDestination(rule *mgmProto.RouteFirewallRule, dynamicResolver bool, sources []netip.Prefix) (firewall.Network, error) {
var destination firewall.Network
if rule.IsDynamic {
if dynamicResolver {
if len(rule.Domains) > 0 {
destination.Set = firewall.NewDomainSet(domain.FromPunycodeList(rule.Domains))
} else {
// isDynamic is set but no domains = outdated management server
log.Warn("connected to an older version of management server (no domains in rules), using default destination")
destination.Prefix = getDefault(sources[0])
}
} else {
// client resolves DNS, we (router) don't know the destination
destination.Prefix = getDefault(sources[0])
}
return destination, nil
}
prefix, err := netip.ParsePrefix(rule.Destination)
if err != nil {
return destination, fmt.Errorf("parse destination: %w", err)
}
destination.Prefix = prefix
return destination, nil
}
func getDefault(prefix netip.Prefix) netip.Prefix {
if prefix.Addr().Is6() {
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)

View File

@@ -66,7 +66,7 @@ func TestDefaultManager(t *testing.T) {
acl := NewDefaultManager(fw)
t.Run("apply firewall rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap)
acl.ApplyFiltering(networkMap, false)
if len(acl.peerRulesPairs) != 2 {
t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs)
@@ -92,7 +92,7 @@ func TestDefaultManager(t *testing.T) {
},
)
acl.ApplyFiltering(networkMap)
acl.ApplyFiltering(networkMap, false)
// we should have one old and one new rule in the existed rules
if len(acl.peerRulesPairs) != 2 {
@@ -116,13 +116,13 @@ func TestDefaultManager(t *testing.T) {
networkMap.FirewallRules = networkMap.FirewallRules[:0]
networkMap.FirewallRulesIsEmpty = true
if acl.ApplyFiltering(networkMap); len(acl.peerRulesPairs) != 0 {
if acl.ApplyFiltering(networkMap, false); len(acl.peerRulesPairs) != 0 {
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs))
return
}
networkMap.FirewallRulesIsEmpty = false
acl.ApplyFiltering(networkMap)
acl.ApplyFiltering(networkMap, false)
if len(acl.peerRulesPairs) != 1 {
t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
return
@@ -359,7 +359,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
}(fw)
acl := NewDefaultManager(fw)
acl.ApplyFiltering(networkMap)
acl.ApplyFiltering(networkMap, false)
if len(acl.peerRulesPairs) != 3 {
t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))

View File

@@ -59,6 +59,16 @@ func collectIPTablesRules() (string, error) {
builder.WriteString("\n")
}
// Collect ipset information
ipsetOutput, err := collectIPSets()
if err != nil {
log.Warnf("Failed to collect ipset information: %v", err)
} else {
builder.WriteString("=== ipset list output ===\n")
builder.WriteString(ipsetOutput)
builder.WriteString("\n")
}
builder.WriteString("=== iptables -v -n -L output ===\n")
tables := []string{"filter", "nat", "mangle", "raw", "security"}
@@ -78,6 +88,28 @@ func collectIPTablesRules() (string, error) {
return builder.String(), nil
}
// collectIPSets collects information about ipsets
func collectIPSets() (string, error) {
cmd := exec.Command("ipset", "list")
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
if strings.Contains(err.Error(), "executable file not found") {
return "", fmt.Errorf("ipset command not found: %w", err)
}
return "", fmt.Errorf("execute ipset list: %w (stderr: %s)", err, stderr.String())
}
ipsets := stdout.String()
if strings.TrimSpace(ipsets) == "" {
return "No ipsets found", nil
}
return ipsets, nil
}
// collectIPTablesSave uses iptables-save to get rule definitions
func collectIPTablesSave() (string, error) {
cmd := exec.Command("iptables-save")

View File

@@ -3,6 +3,7 @@ package dnsfwd
import (
"context"
"errors"
"fmt"
"math"
"net"
"net/netip"
@@ -10,11 +11,16 @@ import (
"sync"
"time"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
const errResolveFailed = "failed to resolve query for domain=%s: %v"
@@ -23,25 +29,27 @@ const upstreamTimeout = 15 * time.Second
type DNSForwarder struct {
listenAddress string
ttl uint32
domains []string
statusRecorder *peer.Status
dnsServer *dns.Server
mux *dns.ServeMux
resId sync.Map
mutex sync.RWMutex
fwdEntries []*ForwarderEntry
firewall firewall.Manager
}
func NewDNSForwarder(listenAddress string, ttl uint32, statusRecorder *peer.Status) *DNSForwarder {
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager, statusRecorder *peer.Status) *DNSForwarder {
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
return &DNSForwarder{
listenAddress: listenAddress,
ttl: ttl,
firewall: firewall,
statusRecorder: statusRecorder,
}
}
func (f *DNSForwarder) Listen(domains []string, resIds map[string]string) error {
func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
log.Infof("listen DNS forwarder on address=%s", f.listenAddress)
mux := dns.NewServeMux()
@@ -53,31 +61,35 @@ func (f *DNSForwarder) Listen(domains []string, resIds map[string]string) error
f.dnsServer = dnsServer
f.mux = mux
f.UpdateDomains(domains, resIds)
f.UpdateDomains(entries)
return dnsServer.ListenAndServe()
}
func (f *DNSForwarder) UpdateDomains(domains []string, resIds map[string]string) {
log.Debugf("Updating domains from %v to %v", f.domains, domains)
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
f.mutex.Lock()
defer f.mutex.Unlock()
for _, d := range f.domains {
f.mux.HandleRemove(d)
if f.mux == nil {
log.Debug("DNS mux is nil, skipping domain update")
f.fwdEntries = entries
return
}
f.resId.Clear()
newDomains := filterDomains(domains)
oldDomains := filterDomains(f.fwdEntries)
for _, d := range oldDomains {
f.mux.HandleRemove(d.PunycodeString())
}
newDomains := filterDomains(entries)
for _, d := range newDomains {
f.mux.HandleFunc(d, f.handleDNSQuery)
f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQuery)
}
for domain, resId := range resIds {
if domain != "" {
f.resId.Store(domain, resId)
}
}
f.fwdEntries = entries
f.domains = newDomains
log.Debugf("Updated domains from %v to %v", oldDomains, newDomains)
}
func (f *DNSForwarder) Close(ctx context.Context) error {
@@ -91,11 +103,11 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
if len(query.Question) == 0 {
return
}
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
query.Question[0].Name, query.Question[0].Qtype, query.Question[0].Qclass)
question := query.Question[0]
domain := question.Name
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
question.Name, question.Qtype, question.Qclass)
domain := strings.ToLower(question.Name)
resp := query.SetReply(query)
var network string
@@ -122,21 +134,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
return
}
resId := f.getResIdForDomain(strings.TrimSuffix(domain, "."))
if resId != "" {
for _, ip := range ips {
var ipWithSuffix string
if ip.Is4() {
ipWithSuffix = ip.String() + "/32"
log.Tracef("resolved domain=%s to IPv4=%s", domain, ipWithSuffix)
} else {
ipWithSuffix = ip.String() + "/128"
log.Tracef("resolved domain=%s to IPv6=%s", domain, ipWithSuffix)
}
f.statusRecorder.AddResolvedIPLookupEntry(ipWithSuffix, resId)
}
}
f.updateInternalState(domain, ips)
f.addIPsToResponse(resp, domain, ips)
if err := w.WriteMsg(resp); err != nil {
@@ -144,6 +142,42 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
}
}
func (f *DNSForwarder) updateInternalState(domain string, ips []netip.Addr) {
var prefixes []netip.Prefix
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
if mostSpecificResId != "" {
for _, ip := range ips {
var prefix netip.Prefix
if ip.Is4() {
prefix = netip.PrefixFrom(ip, 32)
} else {
prefix = netip.PrefixFrom(ip, 128)
}
prefixes = append(prefixes, prefix)
f.statusRecorder.AddResolvedIPLookupEntry(prefix, mostSpecificResId)
}
}
if f.firewall != nil {
f.updateFirewall(matchingEntries, prefixes)
}
}
func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixes []netip.Prefix) {
var merr *multierror.Error
for _, entry := range matchingEntries {
if err := f.firewall.UpdateSet(entry.Set, prefixes); err != nil {
merr = multierror.Append(merr, fmt.Errorf("update set for domain=%s: %w", entry.Domain, err))
}
}
if merr != nil {
log.Errorf("failed to update firewall sets (%d/%d): %v",
len(merr.Errors),
len(matchingEntries),
nberrors.FormatErrorOrNil(merr))
}
}
// handleDNSError processes DNS lookup errors and sends an appropriate error response
func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domain string, err error) {
var dnsErr *net.DNSError
@@ -204,45 +238,53 @@ func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []neti
}
}
func (f *DNSForwarder) getResIdForDomain(domain string) string {
var selectedResId string
// getMatchingEntries retrieves the resource IDs for a given domain.
// It returns the most specific match and all matching resource IDs.
func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*ForwarderEntry) {
var selectedResId route.ResID
var bestScore int
var matches []*ForwarderEntry
f.resId.Range(func(key, value interface{}) bool {
f.mutex.RLock()
defer f.mutex.RUnlock()
for _, entry := range f.fwdEntries {
var score int
pattern := key.(string)
pattern := entry.Domain.PunycodeString()
switch {
case strings.HasPrefix(pattern, "*."):
baseDomain := strings.TrimPrefix(pattern, "*.")
if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) {
if strings.EqualFold(domain, baseDomain) || strings.HasSuffix(domain, "."+baseDomain) {
score = len(baseDomain)
matches = append(matches, entry)
}
case domain == pattern:
score = math.MaxInt
matches = append(matches, entry)
default:
return true
continue
}
if score > bestScore {
bestScore = score
selectedResId = value.(string)
selectedResId = entry.ResID
}
return true
})
}
return selectedResId
return selectedResId, matches
}
// filterDomains returns a list of normalized domains
func filterDomains(domains []string) []string {
newDomains := make([]string, 0, len(domains))
for _, d := range domains {
if d == "" {
func filterDomains(entries []*ForwarderEntry) domain.List {
newDomains := make(domain.List, 0, len(entries))
for _, d := range entries {
if d.Domain == "" {
log.Warn("empty domain in DNS forwarder")
continue
}
newDomains = append(newDomains, nbdns.NormalizeZone(d))
newDomains = append(newDomains, domain.Domain(nbdns.NormalizeZone(d.Domain.PunycodeString())))
}
return newDomains
}

View File

@@ -1,56 +1,61 @@
package dnsfwd
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
func TestGetResIdForDomain(t *testing.T) {
func Test_getMatchingEntries(t *testing.T) {
testCases := []struct {
name string
storedMappings map[string]string // key: domain pattern, value: resId
storedMappings map[string]route.ResID // key: domain pattern, value: resId
queryDomain string
expectedResId string
expectedResId route.ResID
}{
{
name: "Empty map returns empty string",
storedMappings: map[string]string{},
storedMappings: map[string]route.ResID{},
queryDomain: "example.com",
expectedResId: "",
},
{
name: "Exact match returns stored resId",
storedMappings: map[string]string{"example.com": "res1"},
storedMappings: map[string]route.ResID{"example.com": "res1"},
queryDomain: "example.com",
expectedResId: "res1",
},
{
name: "Wildcard pattern matches base domain",
storedMappings: map[string]string{"*.example.com": "res2"},
storedMappings: map[string]route.ResID{"*.example.com": "res2"},
queryDomain: "example.com",
expectedResId: "res2",
},
{
name: "Wildcard pattern matches subdomain",
storedMappings: map[string]string{"*.example.com": "res3"},
storedMappings: map[string]route.ResID{"*.example.com": "res3"},
queryDomain: "foo.example.com",
expectedResId: "res3",
},
{
name: "Wildcard pattern does not match different domain",
storedMappings: map[string]string{"*.example.com": "res4"},
storedMappings: map[string]route.ResID{"*.example.com": "res4"},
queryDomain: "foo.notexample.com",
expectedResId: "",
},
{
name: "Non-wildcard pattern does not match subdomain",
storedMappings: map[string]string{"example.com": "res5"},
storedMappings: map[string]route.ResID{"example.com": "res5"},
queryDomain: "foo.example.com",
expectedResId: "",
},
{
name: "Exact match over overlapping wildcard",
storedMappings: map[string]string{
storedMappings: map[string]route.ResID{
"*.example.com": "resWildcard",
"foo.example.com": "resExact",
},
@@ -59,7 +64,7 @@ func TestGetResIdForDomain(t *testing.T) {
},
{
name: "Overlapping wildcards: Select more specific wildcard",
storedMappings: map[string]string{
storedMappings: map[string]route.ResID{
"*.example.com": "resA",
"*.sub.example.com": "resB",
},
@@ -68,7 +73,7 @@ func TestGetResIdForDomain(t *testing.T) {
},
{
name: "Wildcard multi-level subdomain match",
storedMappings: map[string]string{
storedMappings: map[string]route.ResID{
"*.example.com": "resMulti",
},
queryDomain: "a.b.example.com",
@@ -78,18 +83,21 @@ func TestGetResIdForDomain(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
fwd := &DNSForwarder{
resId: sync.Map{},
}
fwd := &DNSForwarder{}
var entries []*ForwarderEntry
for domainPattern, resId := range tc.storedMappings {
fwd.resId.Store(domainPattern, resId)
d, err := domain.FromString(domainPattern)
require.NoError(t, err)
entries = append(entries, &ForwarderEntry{
Domain: d,
ResID: resId,
})
}
fwd.UpdateDomains(entries)
got := fwd.getResIdForDomain(tc.queryDomain)
if got != tc.expectedResId {
t.Errorf("For query domain %q, expected resId %q, but got %q", tc.queryDomain, tc.expectedResId, got)
}
got, _ := fwd.getMatchingEntries(tc.queryDomain)
assert.Equal(t, got, tc.expectedResId)
})
}
}

View File

@@ -11,6 +11,8 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
const (
@@ -19,6 +21,13 @@ const (
dnsTTL = 60 //seconds
)
// ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list.
type ForwarderEntry struct {
Domain domain.Domain
ResID route.ResID
Set firewall.Set
}
type Manager struct {
firewall firewall.Manager
statusRecorder *peer.Status
@@ -34,7 +43,7 @@ func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager {
}
}
func (m *Manager) Start(domains []string, resIds map[string]string) error {
func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
log.Infof("starting DNS forwarder")
if m.dnsForwarder != nil {
return nil
@@ -44,9 +53,9 @@ func (m *Manager) Start(domains []string, resIds map[string]string) error {
return err
}
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, m.statusRecorder)
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, m.firewall, m.statusRecorder)
go func() {
if err := m.dnsForwarder.Listen(domains, resIds); err != nil {
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
// todo handle close error if it is exists
log.Errorf("failed to start DNS forwarder, err: %v", err)
}
@@ -55,12 +64,12 @@ func (m *Manager) Start(domains []string, resIds map[string]string) error {
return nil
}
func (m *Manager) UpdateDomains(domains []string, resIds map[string]string) {
func (m *Manager) UpdateDomains(entries []*ForwarderEntry) {
if m.dnsForwarder == nil {
return
}
m.dnsForwarder.UpdateDomains(domains, resIds)
m.dnsForwarder.UpdateDomains(entries)
}
func (m *Manager) Stop(ctx context.Context) error {
@@ -81,34 +90,34 @@ func (m *Manager) Stop(ctx context.Context) error {
return nberrors.FormatErrorOrNil(mErr)
}
func (h *Manager) allowDNSFirewall() error {
func (m *Manager) allowDNSFirewall() error {
dport := &firewall.Port{
IsRange: false,
Values: []uint16{ListenPort},
}
if h.firewall == nil {
if m.firewall == nil {
return nil
}
dnsRules, err := h.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err)
return err
}
h.fwRules = dnsRules
m.fwRules = dnsRules
return nil
}
func (h *Manager) dropDNSFirewall() error {
func (m *Manager) dropDNSFirewall() error {
var mErr *multierror.Error
for _, rule := range h.fwRules {
if err := h.firewall.DeletePeerRule(rule); err != nil {
for _, rule := range m.fwRules {
if err := m.firewall.DeletePeerRule(rule); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
}
}
h.fwRules = nil
m.fwRules = nil
return nberrors.FormatErrorOrNil(mErr)
}

View File

@@ -527,7 +527,7 @@ func (e *Engine) blockLanAccess() {
if _, err := e.firewall.AddRouteFiltering(
nil,
[]netip.Prefix{v4},
network,
firewallManager.Network{Prefix: network},
firewallManager.ProtocolALL,
nil,
nil,
@@ -960,21 +960,21 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
}
}
// DNS forwarder
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
dnsRouteDomains, resourceIds := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes())
e.updateDNSForwarder(dnsRouteFeatureFlag, dnsRouteDomains, resourceIds)
// apply routes first, route related actions might depend on routing being enabled
routes := toRoutes(networkMap.GetRoutes())
if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err)
}
// acls might need routing to be enabled, so we apply after routes
if e.acl != nil {
e.acl.ApplyFiltering(networkMap)
e.acl.ApplyFiltering(networkMap, dnsRouteFeatureFlag)
}
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
// Ingress forward rules
if err := e.updateForwardRules(networkMap.GetForwardingRules()); err != nil {
log.Errorf("failed to update forward rules, err: %v", err)
@@ -1079,29 +1079,24 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
return routes
}
func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) ([]string, map[string]string) {
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}
var dnsRoutes []string
resIds := make(map[string]string)
for _, protoRoute := range protoRoutes {
if len(protoRoute.Domains) == 0 {
func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderEntry {
var entries []*dnsfwd.ForwarderEntry
for _, route := range routes {
if len(route.Domains) == 0 {
continue
}
if protoRoute.Peer == myPubKey {
dnsRoutes = append(dnsRoutes, protoRoute.Domains...)
// resource ID is the first part of the ID
resId := strings.Split(protoRoute.ID, ":")
for _, domain := range protoRoute.Domains {
if len(resId) > 0 {
resIds[domain] = resId[0]
}
if route.Peer == myPubKey {
domainSet := firewallManager.NewDomainSet(route.Domains)
for _, d := range route.Domains {
entries = append(entries, &dnsfwd.ForwarderEntry{
Domain: d,
Set: domainSet,
ResID: route.GetResourceID(),
})
}
}
}
return dnsRoutes, resIds
return entries
}
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config {
@@ -1751,7 +1746,10 @@ func (e *Engine) GetWgAddr() net.IP {
}
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
func (e *Engine) updateDNSForwarder(enabled bool, domains []string, resIds map[string]string) {
func (e *Engine) updateDNSForwarder(
enabled bool,
fwdEntries []*dnsfwd.ForwarderEntry,
) {
if !enabled {
if e.dnsForwardMgr == nil {
return
@@ -1762,18 +1760,18 @@ func (e *Engine) updateDNSForwarder(enabled bool, domains []string, resIds map[s
return
}
if len(domains) > 0 {
log.Infof("enable domain router service for domains: %v", domains)
if len(fwdEntries) > 0 {
if e.dnsForwardMgr == nil {
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
if err := e.dnsForwardMgr.Start(domains, resIds); err != nil {
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil
}
log.Infof("started domain router service with %d entries", len(fwdEntries))
} else {
log.Infof("update domain router service for domains: %v", domains)
e.dnsForwardMgr.UpdateDomains(domains, resIds)
e.dnsForwardMgr.UpdateDomains(fwdEntries)
}
} else if e.dnsForwardMgr != nil {
log.Infof("disable domain router service")

View File

@@ -6,12 +6,14 @@ import (
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/route"
)
// routeEntry holds the route prefix and the corresponding resource ID.
type routeEntry struct {
prefix netip.Prefix
resourceID string
resourceID route.ResID
}
type routeIDLookup struct {
@@ -24,7 +26,7 @@ type routeIDLookup struct {
resolvedIPs sync.Map
}
func (r *routeIDLookup) AddLocalRouteID(resourceID string, route netip.Prefix) {
func (r *routeIDLookup) AddLocalRouteID(resourceID route.ResID, route netip.Prefix) {
r.localLock.Lock()
defer r.localLock.Unlock()
@@ -56,7 +58,7 @@ func (r *routeIDLookup) RemoveLocalRouteID(route netip.Prefix) {
}
}
func (r *routeIDLookup) AddRemoteRouteID(resourceID string, route netip.Prefix) {
func (r *routeIDLookup) AddRemoteRouteID(resourceID route.ResID, route netip.Prefix) {
r.remoteLock.Lock()
defer r.remoteLock.Unlock()
@@ -87,7 +89,7 @@ func (r *routeIDLookup) RemoveRemoteRouteID(route netip.Prefix) {
}
}
func (r *routeIDLookup) AddResolvedIP(resourceID string, route netip.Prefix) {
func (r *routeIDLookup) AddResolvedIP(resourceID route.ResID, route netip.Prefix) {
r.resolvedIPs.Store(route.Addr(), resourceID)
}
@@ -97,19 +99,19 @@ func (r *routeIDLookup) RemoveResolvedIP(route netip.Prefix) {
// Lookup returns the resource ID for the given IP address
// and a bool indicating if the IP is an exit node.
func (r *routeIDLookup) Lookup(ip netip.Addr) (string, bool) {
func (r *routeIDLookup) Lookup(ip netip.Addr) (route.ResID, bool) {
if res, ok := r.resolvedIPs.Load(ip); ok {
return res.(string), false
return res.(route.ResID), false
}
var resourceID string
var resourceID route.ResID
var isExitNode bool
r.localLock.RLock()
for _, entry := range r.localRoutes {
if entry.prefix.Contains(ip) {
resourceID = entry.resourceID
isExitNode = (entry.prefix.Bits() == 0)
isExitNode = entry.prefix.Bits() == 0
break
}
}
@@ -120,7 +122,7 @@ func (r *routeIDLookup) Lookup(ip netip.Addr) (string, bool) {
for _, entry := range r.remoteRoutes {
if entry.prefix.Contains(ip) {
resourceID = entry.resourceID
isExitNode = (entry.prefix.Bits() == 0)
isExitNode = entry.prefix.Bits() == 0
break
}
}

View File

@@ -21,6 +21,7 @@ import (
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/domain"
relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
)
const eventQueueSize = 10
@@ -313,7 +314,7 @@ func (d *Status) UpdatePeerState(receivedState State) error {
return nil
}
func (d *Status) AddPeerStateRoute(peer string, route string, resourceId string) error {
func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.ResID) error {
d.mux.Lock()
defer d.mux.Unlock()
@@ -581,7 +582,7 @@ func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
}
// AddLocalPeerStateRoute adds a route to the local peer state
func (d *Status) AddLocalPeerStateRoute(route, resourceId string) {
func (d *Status) AddLocalPeerStateRoute(route string, resourceId route.ResID) {
d.mux.Lock()
defer d.mux.Unlock()
@@ -611,14 +612,11 @@ func (d *Status) RemoveLocalPeerStateRoute(route string) {
}
// AddResolvedIPLookupEntry adds a resolved IP lookup entry
func (d *Status) AddResolvedIPLookupEntry(route, resourceId string) {
func (d *Status) AddResolvedIPLookupEntry(prefix netip.Prefix, resourceId route.ResID) {
d.mux.Lock()
defer d.mux.Unlock()
pref, err := netip.ParsePrefix(route)
if err == nil {
d.routeIDLookup.AddResolvedIP(resourceId, pref)
}
d.routeIDLookup.AddResolvedIP(resourceId, prefix)
}
// RemoveResolvedIPLookupEntry removes a resolved IP lookup entry
@@ -723,7 +721,7 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
d.nsGroupStates = dnsStates
}
func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix, resourceId string) {
func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix, resourceId route.ResID) {
d.mux.Lock()
defer d.mux.Unlock()

View File

@@ -234,7 +234,7 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
origPattern = writer.GetOrigPattern()
}
resolvedDomain := domain.Domain(r.Question[0].Name)
resolvedDomain := domain.Domain(strings.ToLower(r.Question[0].Name))
// already punycode via RegisterHandler()
originalDomain := domain.Domain(origPattern)
@@ -328,6 +328,11 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
// Update domain prefixes using resolved domain as key
if len(toAdd) > 0 || len(toRemove) > 0 {
if d.route.KeepRoute {
// replace stored prefixes with old + added
// nolint:gocritic
newPrefixes = append(oldPrefixes, toAdd...)
}
d.interceptedDomains[resolvedDomain] = newPrefixes
originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
@@ -338,7 +343,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
originalDomain.SafeString(),
toAdd)
}
if len(toRemove) > 0 {
if len(toRemove) > 0 && !d.route.KeepRoute {
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),

View File

@@ -259,8 +259,6 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
}
}
m.ctx = nil
m.mux.Lock()
defer m.mux.Unlock()
m.clientRoutes = nil
@@ -292,7 +290,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
return nil
}
if err := m.serverRouter.updateRoutes(newServerRoutesMap); err != nil {
if err := m.serverRouter.updateRoutes(newServerRoutesMap, useNewDNSRoute); err != nil {
return fmt.Errorf("update routes: %w", err)
}

View File

@@ -18,7 +18,7 @@ type serverRouter struct {
func (r serverRouter) cleanUp() {
}
func (r serverRouter) updateRoutes(map[route.ID]*route.Route) error {
func (r serverRouter) updateRoutes(map[route.ID]*route.Route, bool) error {
return nil
}

View File

@@ -35,7 +35,10 @@ func newServerRouter(ctx context.Context, wgInterface iface.WGIface, firewall fi
}, nil
}
func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route, useNewDNSRoute bool) error {
m.mux.Lock()
defer m.mux.Unlock()
serverRoutesToRemove := make([]route.ID, 0)
for routeID := range m.routes {
@@ -73,7 +76,7 @@ func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
continue
}
err := m.addToServerNetwork(newRoute)
err := m.addToServerNetwork(newRoute, useNewDNSRoute)
if err != nil {
log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err)
continue
@@ -90,57 +93,30 @@ func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
return m.ctx.Err()
}
m.mux.Lock()
defer m.mux.Unlock()
routerPair, err := routeToRouterPair(route)
if err != nil {
return fmt.Errorf("parse prefix: %w", err)
}
err = m.firewall.RemoveNatRule(routerPair)
if err != nil {
routerPair := routeToRouterPair(route, false)
if err := m.firewall.RemoveNatRule(routerPair); err != nil {
return fmt.Errorf("remove routing rules: %w", err)
}
delete(m.routes, route.ID)
routeStr := route.Network.String()
if route.IsDynamic() {
routeStr = route.Domains.SafeString()
}
m.statusRecorder.RemoveLocalPeerStateRoute(routeStr)
m.statusRecorder.RemoveLocalPeerStateRoute(route.NetString())
return nil
}
func (m *serverRouter) addToServerNetwork(route *route.Route) error {
func (m *serverRouter) addToServerNetwork(route *route.Route, useNewDNSRoute bool) error {
if m.ctx.Err() != nil {
log.Infof("Not adding to server network because context is done")
return m.ctx.Err()
}
m.mux.Lock()
defer m.mux.Unlock()
routerPair, err := routeToRouterPair(route)
if err != nil {
return fmt.Errorf("parse prefix: %w", err)
}
err = m.firewall.AddNatRule(routerPair)
if err != nil {
routerPair := routeToRouterPair(route, useNewDNSRoute)
if err := m.firewall.AddNatRule(routerPair); err != nil {
return fmt.Errorf("insert routing rules: %w", err)
}
m.routes[route.ID] = route
routeStr := route.Network.String()
if route.IsDynamic() {
routeStr = route.Domains.SafeString()
}
m.statusRecorder.AddLocalPeerStateRoute(routeStr, route.GetResourceID())
m.statusRecorder.AddLocalPeerStateRoute(route.NetString(), route.GetResourceID())
return nil
}
@@ -148,31 +124,29 @@ func (m *serverRouter) addToServerNetwork(route *route.Route) error {
func (m *serverRouter) cleanUp() {
m.mux.Lock()
defer m.mux.Unlock()
for _, r := range m.routes {
routerPair, err := routeToRouterPair(r)
if err != nil {
log.Errorf("Failed to convert route to router pair: %v", err)
continue
}
err = m.firewall.RemoveNatRule(routerPair)
if err != nil {
for _, r := range m.routes {
routerPair := routeToRouterPair(r, false)
if err := m.firewall.RemoveNatRule(routerPair); err != nil {
log.Errorf("Failed to remove cleanup route: %v", err)
}
}
m.statusRecorder.CleanLocalPeerStateRoutes()
}
func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) {
// TODO: add ipv6
func routeToRouterPair(route *route.Route, useNewDNSRoute bool) firewall.RouterPair {
source := getDefaultPrefix(route.Network)
destination := route.Network.Masked()
destination := firewall.Network{}
if route.IsDynamic() {
// TODO: add ipv6 additionally
destination = getDefaultPrefix(destination)
if useNewDNSRoute {
destination.Set = firewall.NewDomainSet(route.Domains)
} else {
// TODO: add ipv6 additionally
destination = getDefaultPrefix(destination.Prefix)
}
} else {
destination.Prefix = route.Network.Masked()
}
return firewall.RouterPair{
@@ -180,12 +154,16 @@ func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) {
Source: source,
Destination: destination,
Masquerade: route.Masquerade,
}, nil
}
}
func getDefaultPrefix(prefix netip.Prefix) netip.Prefix {
func getDefaultPrefix(prefix netip.Prefix) firewall.Network {
if prefix.Addr().Is6() {
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
return firewall.Network{
Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
}
}
return firewall.Network{
Prefix: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
}
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
}

View File

@@ -45,7 +45,7 @@ var sysctlFailed bool
type ruleParams struct {
priority int
fwmark int
fwmark uint32
tableID int
family int
invert bool
@@ -55,8 +55,8 @@ type ruleParams struct {
func getSetupRules() []ruleParams {
return []ruleParams{
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"},
{100, 0, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
{100, 0, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"},
{110, nbnet.ControlPlaneMark, NetbirdVPNTableID, netlink.FAMILY_V4, true, -1, "rule v4 netbird"},
{110, nbnet.ControlPlaneMark, NetbirdVPNTableID, netlink.FAMILY_V6, true, -1, "rule v6 netbird"},
}