mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
Merge branch 'main' into refactor/getaccount-raw
This commit is contained in:
@@ -50,6 +50,12 @@ const (
|
|||||||
|
|
||||||
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
|
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
|
||||||
|
|
||||||
|
// serviceKey represents a protocol/port combination for netstack service registry
|
||||||
|
type serviceKey struct {
|
||||||
|
protocol gopacket.LayerType
|
||||||
|
port uint16
|
||||||
|
}
|
||||||
|
|
||||||
// RuleSet is a set of rules grouped by a string key
|
// RuleSet is a set of rules grouped by a string key
|
||||||
type RuleSet map[string]PeerRule
|
type RuleSet map[string]PeerRule
|
||||||
|
|
||||||
@@ -113,6 +119,9 @@ type Manager struct {
|
|||||||
portDNATEnabled atomic.Bool
|
portDNATEnabled atomic.Bool
|
||||||
portDNATRules []portDNATRule
|
portDNATRules []portDNATRule
|
||||||
portDNATMutex sync.RWMutex
|
portDNATMutex sync.RWMutex
|
||||||
|
|
||||||
|
netstackServices map[serviceKey]struct{}
|
||||||
|
netstackServiceMutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// decoder for packages
|
// decoder for packages
|
||||||
@@ -203,6 +212,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
localForwarding: enableLocalForwarding,
|
localForwarding: enableLocalForwarding,
|
||||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||||
portDNATRules: []portDNATRule{},
|
portDNATRules: []portDNATRule{},
|
||||||
|
netstackServices: make(map[serviceKey]struct{}),
|
||||||
}
|
}
|
||||||
m.routingEnabled.Store(false)
|
m.routingEnabled.Store(false)
|
||||||
|
|
||||||
@@ -838,9 +848,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// If requested we pass local traffic to internal interfaces to the forwarder.
|
if m.shouldForward(d, dstIP) {
|
||||||
// netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder.
|
|
||||||
if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) {
|
|
||||||
return m.handleForwardedLocalTraffic(packetData)
|
return m.handleForwardedLocalTraffic(packetData)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1274,3 +1282,86 @@ func (m *Manager) DisableRouting() error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterNetstackService registers a service as listening on the netstack for the given protocol and port
|
||||||
|
func (m *Manager) RegisterNetstackService(protocol nftypes.Protocol, port uint16) {
|
||||||
|
m.netstackServiceMutex.Lock()
|
||||||
|
defer m.netstackServiceMutex.Unlock()
|
||||||
|
layerType := m.protocolToLayerType(protocol)
|
||||||
|
key := serviceKey{protocol: layerType, port: port}
|
||||||
|
m.netstackServices[key] = struct{}{}
|
||||||
|
m.logger.Debug3("RegisterNetstackService: registered %s:%d (layerType=%s)", protocol, port, layerType)
|
||||||
|
m.logger.Debug1("RegisterNetstackService: current registry size: %d", len(m.netstackServices))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnregisterNetstackService removes a service from the netstack registry
|
||||||
|
func (m *Manager) UnregisterNetstackService(protocol nftypes.Protocol, port uint16) {
|
||||||
|
m.netstackServiceMutex.Lock()
|
||||||
|
defer m.netstackServiceMutex.Unlock()
|
||||||
|
layerType := m.protocolToLayerType(protocol)
|
||||||
|
key := serviceKey{protocol: layerType, port: port}
|
||||||
|
delete(m.netstackServices, key)
|
||||||
|
m.logger.Debug2("Unregistered netstack service on protocol %s port %d", protocol, port)
|
||||||
|
}
|
||||||
|
|
||||||
|
// protocolToLayerType converts nftypes.Protocol to gopacket.LayerType for internal use
|
||||||
|
func (m *Manager) protocolToLayerType(protocol nftypes.Protocol) gopacket.LayerType {
|
||||||
|
switch protocol {
|
||||||
|
case nftypes.TCP:
|
||||||
|
return layers.LayerTypeTCP
|
||||||
|
case nftypes.UDP:
|
||||||
|
return layers.LayerTypeUDP
|
||||||
|
case nftypes.ICMP:
|
||||||
|
return layers.LayerTypeICMPv4
|
||||||
|
default:
|
||||||
|
return gopacket.LayerType(0) // Invalid/unknown
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// shouldForward determines if a packet should be forwarded to the forwarder.
|
||||||
|
// The forwarder handles routing packets to the native OS network stack.
|
||||||
|
// Returns true if packet should go to the forwarder, false if it should go to netstack listeners or the native stack directly.
|
||||||
|
func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool {
|
||||||
|
// not enabled, never forward
|
||||||
|
if !m.localForwarding {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// netstack always needs to forward because it's lacking a native interface
|
||||||
|
// exception for registered netstack services, those should go to netstack listeners
|
||||||
|
if m.netstack {
|
||||||
|
return !m.hasMatchingNetstackService(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// traffic to our other local interfaces (not NetBird IP) - always forward
|
||||||
|
if dstIP != m.wgIface.Address().IP {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// traffic to our NetBird IP, not netstack mode - send to netstack listeners
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasMatchingNetstackService checks if there's a registered netstack service for this packet
|
||||||
|
func (m *Manager) hasMatchingNetstackService(d *decoder) bool {
|
||||||
|
if len(d.decoded) < 2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var dstPort uint16
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
dstPort = uint16(d.tcp.DstPort)
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
dstPort = uint16(d.udp.DstPort)
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
key := serviceKey{protocol: d.decoded[1], port: dstPort}
|
||||||
|
m.netstackServiceMutex.RLock()
|
||||||
|
_, exists := m.netstackServices[key]
|
||||||
|
m.netstackServiceMutex.RUnlock()
|
||||||
|
|
||||||
|
return exists
|
||||||
|
}
|
||||||
|
|||||||
@@ -83,4 +83,3 @@ func TestCacheMiss(t *testing.T) {
|
|||||||
t.Fatalf("expected cache miss, got=%v ok=%v", got, ok)
|
t.Fatalf("expected cache miss, got=%v ok=%v", got, ok)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
@@ -33,7 +34,7 @@ type firewaller interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type DNSForwarder struct {
|
type DNSForwarder struct {
|
||||||
listenAddress string
|
listenAddress netip.AddrPort
|
||||||
ttl uint32
|
ttl uint32
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
|
|
||||||
@@ -47,9 +48,11 @@ type DNSForwarder struct {
|
|||||||
firewall firewaller
|
firewall firewaller
|
||||||
resolver resolver
|
resolver resolver
|
||||||
cache *cache
|
cache *cache
|
||||||
|
|
||||||
|
wgIface wgIface
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
|
func NewDNSForwarder(listenAddress netip.AddrPort, ttl uint32, firewall firewaller, statusRecorder *peer.Status, wgIface wgIface) *DNSForwarder {
|
||||||
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
|
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
|
||||||
return &DNSForwarder{
|
return &DNSForwarder{
|
||||||
listenAddress: listenAddress,
|
listenAddress: listenAddress,
|
||||||
@@ -58,30 +61,46 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat
|
|||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
resolver: net.DefaultResolver,
|
resolver: net.DefaultResolver,
|
||||||
cache: newCache(),
|
cache: newCache(),
|
||||||
|
wgIface: wgIface,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
|
func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
|
||||||
log.Infof("starting DNS forwarder on address=%s", f.listenAddress)
|
var netstackNet *netstack.Net
|
||||||
|
if f.wgIface != nil {
|
||||||
|
netstackNet = f.wgIface.GetNet()
|
||||||
|
}
|
||||||
|
|
||||||
|
addrDesc := f.listenAddress.String()
|
||||||
|
if netstackNet != nil {
|
||||||
|
addrDesc = fmt.Sprintf("netstack %s", f.listenAddress)
|
||||||
|
}
|
||||||
|
log.Infof("starting DNS forwarder on address=%s", addrDesc)
|
||||||
|
|
||||||
|
udpLn, err := f.createUDPListener(netstackNet)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create UDP listener: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tcpLn, err := f.createTCPListener(netstackNet)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create TCP listener: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// UDP server
|
|
||||||
mux := dns.NewServeMux()
|
mux := dns.NewServeMux()
|
||||||
f.mux = mux
|
f.mux = mux
|
||||||
mux.HandleFunc(".", f.handleDNSQueryUDP)
|
mux.HandleFunc(".", f.handleDNSQueryUDP)
|
||||||
f.dnsServer = &dns.Server{
|
f.dnsServer = &dns.Server{
|
||||||
Addr: f.listenAddress,
|
PacketConn: udpLn,
|
||||||
Net: "udp",
|
Handler: mux,
|
||||||
Handler: mux,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TCP server
|
|
||||||
tcpMux := dns.NewServeMux()
|
tcpMux := dns.NewServeMux()
|
||||||
f.tcpMux = tcpMux
|
f.tcpMux = tcpMux
|
||||||
tcpMux.HandleFunc(".", f.handleDNSQueryTCP)
|
tcpMux.HandleFunc(".", f.handleDNSQueryTCP)
|
||||||
f.tcpServer = &dns.Server{
|
f.tcpServer = &dns.Server{
|
||||||
Addr: f.listenAddress,
|
Listener: tcpLn,
|
||||||
Net: "tcp",
|
Handler: tcpMux,
|
||||||
Handler: tcpMux,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
f.UpdateDomains(entries)
|
f.UpdateDomains(entries)
|
||||||
@@ -89,18 +108,33 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
|
|||||||
errCh := make(chan error, 2)
|
errCh := make(chan error, 2)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
log.Infof("DNS UDP listener running on %s", f.listenAddress)
|
log.Infof("DNS UDP listener running on %s", addrDesc)
|
||||||
errCh <- f.dnsServer.ListenAndServe()
|
errCh <- f.dnsServer.ActivateAndServe()
|
||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
log.Infof("DNS TCP listener running on %s", f.listenAddress)
|
log.Infof("DNS TCP listener running on %s", addrDesc)
|
||||||
errCh <- f.tcpServer.ListenAndServe()
|
errCh <- f.tcpServer.ActivateAndServe()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// return the first error we get (e.g. bind failure or shutdown)
|
|
||||||
return <-errCh
|
return <-errCh
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *DNSForwarder) createUDPListener(netstackNet *netstack.Net) (net.PacketConn, error) {
|
||||||
|
if netstackNet != nil {
|
||||||
|
return netstackNet.ListenUDPAddrPort(f.listenAddress)
|
||||||
|
}
|
||||||
|
|
||||||
|
return net.ListenUDP("udp", net.UDPAddrFromAddrPort(f.listenAddress))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DNSForwarder) createTCPListener(netstackNet *netstack.Net) (net.Listener, error) {
|
||||||
|
if netstackNet != nil {
|
||||||
|
return netstackNet.ListenTCPAddrPort(f.listenAddress)
|
||||||
|
}
|
||||||
|
|
||||||
|
return net.ListenTCP("tcp", net.TCPAddrFromAddrPort(f.listenAddress))
|
||||||
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
|
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
|
||||||
f.mutex.Lock()
|
f.mutex.Lock()
|
||||||
defer f.mutex.Unlock()
|
defer f.mutex.Unlock()
|
||||||
|
|||||||
@@ -297,7 +297,7 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
|||||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil)
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
|
||||||
forwarder.resolver = mockResolver
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
d, err := domain.FromString(tt.configuredDomain)
|
d, err := domain.FromString(tt.configuredDomain)
|
||||||
@@ -402,7 +402,7 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
|
|||||||
mockResolver := &MockResolver{}
|
mockResolver := &MockResolver{}
|
||||||
|
|
||||||
// Set up forwarder
|
// Set up forwarder
|
||||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
|
||||||
forwarder.resolver = mockResolver
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
// Create entries and track sets
|
// Create entries and track sets
|
||||||
@@ -489,7 +489,7 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
|
|||||||
mockFirewall := &MockFirewall{}
|
mockFirewall := &MockFirewall{}
|
||||||
mockResolver := &MockResolver{}
|
mockResolver := &MockResolver{}
|
||||||
|
|
||||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
|
||||||
forwarder.resolver = mockResolver
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
// Configure a single domain
|
// Configure a single domain
|
||||||
@@ -584,7 +584,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||||
|
|
||||||
d, err := domain.FromString(tt.configured)
|
d, err := domain.FromString(tt.configured)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -616,7 +616,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
|||||||
func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
||||||
// Test that large UDP responses are truncated with TC bit set
|
// Test that large UDP responses are truncated with TC bit set
|
||||||
mockResolver := &MockResolver{}
|
mockResolver := &MockResolver{}
|
||||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||||
forwarder.resolver = mockResolver
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
d, _ := domain.FromString("example.com")
|
d, _ := domain.FromString("example.com")
|
||||||
@@ -652,7 +652,7 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
|||||||
// a subsequent upstream failure still returns a successful response from cache.
|
// a subsequent upstream failure still returns a successful response from cache.
|
||||||
func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
||||||
mockResolver := &MockResolver{}
|
mockResolver := &MockResolver{}
|
||||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||||
forwarder.resolver = mockResolver
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
d, err := domain.FromString("example.com")
|
d, err := domain.FromString("example.com")
|
||||||
@@ -696,7 +696,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
|||||||
// Verifies that cache normalization works across casing and trailing dot variations.
|
// Verifies that cache normalization works across casing and trailing dot variations.
|
||||||
func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
||||||
mockResolver := &MockResolver{}
|
mockResolver := &MockResolver{}
|
||||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||||
forwarder.resolver = mockResolver
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
d, err := domain.FromString("ExAmPlE.CoM")
|
d, err := domain.FromString("ExAmPlE.CoM")
|
||||||
@@ -742,7 +742,7 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
|||||||
mockFirewall := &MockFirewall{}
|
mockFirewall := &MockFirewall{}
|
||||||
mockResolver := &MockResolver{}
|
mockResolver := &MockResolver{}
|
||||||
|
|
||||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
|
||||||
forwarder.resolver = mockResolver
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
// Set up complex overlapping patterns
|
// Set up complex overlapping patterns
|
||||||
@@ -804,7 +804,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
|
|||||||
mockFirewall := &MockFirewall{}
|
mockFirewall := &MockFirewall{}
|
||||||
mockResolver := &MockResolver{}
|
mockResolver := &MockResolver{}
|
||||||
|
|
||||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
|
||||||
forwarder.resolver = mockResolver
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
d, err := domain.FromString("example.com")
|
d, err := domain.FromString("example.com")
|
||||||
@@ -925,7 +925,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
|
|||||||
|
|
||||||
func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
||||||
// Test handling of malformed query with no questions
|
// Test handling of malformed query with no questions
|
||||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||||
|
|
||||||
query := &dns.Msg{}
|
query := &dns.Msg{}
|
||||||
// Don't set any question
|
// Don't set any question
|
||||||
|
|||||||
@@ -10,9 +10,11 @@ import (
|
|||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
@@ -24,6 +26,12 @@ const (
|
|||||||
envServerPort = "NB_DNS_FORWARDER_PORT"
|
envServerPort = "NB_DNS_FORWARDER_PORT"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// wgIface defines the interface for WireGuard interface operations needed by the DNS forwarder.
|
||||||
|
type wgIface interface {
|
||||||
|
GetNet() *netstack.Net
|
||||||
|
Address() wgaddr.Address
|
||||||
|
}
|
||||||
|
|
||||||
// ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list.
|
// ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list.
|
||||||
type ForwarderEntry struct {
|
type ForwarderEntry struct {
|
||||||
Domain domain.Domain
|
Domain domain.Domain
|
||||||
@@ -34,7 +42,7 @@ type ForwarderEntry struct {
|
|||||||
type Manager struct {
|
type Manager struct {
|
||||||
firewall firewall.Manager
|
firewall firewall.Manager
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
localAddr netip.Addr
|
wgIface wgIface
|
||||||
serverPort uint16
|
serverPort uint16
|
||||||
|
|
||||||
fwRules []firewall.Rule
|
fwRules []firewall.Rule
|
||||||
@@ -42,7 +50,7 @@ type Manager struct {
|
|||||||
dnsForwarder *DNSForwarder
|
dnsForwarder *DNSForwarder
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, localAddr netip.Addr) *Manager {
|
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, wgIface wgIface) *Manager {
|
||||||
serverPort := nbdns.ForwarderServerPort
|
serverPort := nbdns.ForwarderServerPort
|
||||||
if envPort := os.Getenv(envServerPort); envPort != "" {
|
if envPort := os.Getenv(envServerPort); envPort != "" {
|
||||||
if port, err := strconv.ParseUint(envPort, 10, 16); err == nil && port > 0 {
|
if port, err := strconv.ParseUint(envPort, 10, 16); err == nil && port > 0 {
|
||||||
@@ -56,7 +64,7 @@ func NewManager(fw firewall.Manager, statusRecorder *peer.Status, localAddr neti
|
|||||||
return &Manager{
|
return &Manager{
|
||||||
firewall: fw,
|
firewall: fw,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
localAddr: localAddr,
|
wgIface: wgIface,
|
||||||
serverPort: serverPort,
|
serverPort: serverPort,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -71,21 +79,25 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.localAddr.IsValid() && m.firewall != nil {
|
localAddr := m.wgIface.Address().IP
|
||||||
if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
|
||||||
|
if localAddr.IsValid() && m.firewall != nil {
|
||||||
|
if err := m.firewall.AddInboundDNAT(localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||||
log.Warnf("failed to add DNS UDP DNAT rule: %v", err)
|
log.Warnf("failed to add DNS UDP DNAT rule: %v", err)
|
||||||
} else {
|
} else {
|
||||||
log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort)
|
log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", localAddr, nbdns.ForwarderClientPort, localAddr, m.serverPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
if err := m.firewall.AddInboundDNAT(localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||||
log.Warnf("failed to add DNS TCP DNAT rule: %v", err)
|
log.Warnf("failed to add DNS TCP DNAT rule: %v", err)
|
||||||
} else {
|
} else {
|
||||||
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort)
|
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", localAddr, nbdns.ForwarderClientPort, localAddr, m.serverPort)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", m.serverPort), dnsTTL, m.firewall, m.statusRecorder)
|
listenAddress := netip.AddrPortFrom(localAddr, m.serverPort)
|
||||||
|
m.dnsForwarder = NewDNSForwarder(listenAddress, dnsTTL, m.firewall, m.statusRecorder, m.wgIface)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
|
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
|
||||||
// todo handle close error if it is exists
|
// todo handle close error if it is exists
|
||||||
@@ -111,12 +123,13 @@ func (m *Manager) Stop(ctx context.Context) error {
|
|||||||
|
|
||||||
var mErr *multierror.Error
|
var mErr *multierror.Error
|
||||||
|
|
||||||
if m.localAddr.IsValid() && m.firewall != nil {
|
localAddr := m.wgIface.Address().IP
|
||||||
if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
if localAddr.IsValid() && m.firewall != nil {
|
||||||
|
if err := m.firewall.RemoveInboundDNAT(localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||||
mErr = multierror.Append(mErr, fmt.Errorf("remove DNS UDP DNAT rule: %w", err))
|
mErr = multierror.Append(mErr, fmt.Errorf("remove DNS UDP DNAT rule: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
if err := m.firewall.RemoveInboundDNAT(localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||||
mErr = multierror.Append(mErr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err))
|
mErr = multierror.Append(mErr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1855,35 +1855,69 @@ func (e *Engine) updateDNSForwarder(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !enabled {
|
if !enabled {
|
||||||
if e.dnsForwardMgr == nil {
|
e.stopDNSForwarder()
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
|
||||||
log.Errorf("failed to stop DNS forward: %v", err)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(fwdEntries) > 0 {
|
if len(fwdEntries) > 0 {
|
||||||
if e.dnsForwardMgr == nil {
|
if e.dnsForwardMgr == nil {
|
||||||
localAddr := e.wgInterface.Address().IP
|
e.startDNSForwarder(fwdEntries)
|
||||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, localAddr)
|
|
||||||
|
|
||||||
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
|
||||||
log.Errorf("failed to start DNS forward: %v", err)
|
|
||||||
e.dnsForwardMgr = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("started domain router service with %d entries", len(fwdEntries))
|
|
||||||
} else {
|
} else {
|
||||||
e.dnsForwardMgr.UpdateDomains(fwdEntries)
|
e.dnsForwardMgr.UpdateDomains(fwdEntries)
|
||||||
}
|
}
|
||||||
} else if e.dnsForwardMgr != nil {
|
} else if e.dnsForwardMgr != nil {
|
||||||
log.Infof("disable domain router service")
|
log.Infof("disable domain router service")
|
||||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
e.stopDNSForwarder()
|
||||||
log.Errorf("failed to stop DNS forward: %v", err)
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Engine) startDNSForwarder(fwdEntries []*dnsfwd.ForwarderEntry) {
|
||||||
|
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, e.wgInterface)
|
||||||
|
e.registerDNSServices()
|
||||||
|
|
||||||
|
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||||
|
log.Errorf("failed to start DNS forward: %v", err)
|
||||||
e.dnsForwardMgr = nil
|
e.dnsForwardMgr = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("started domain router service with %d entries", len(fwdEntries))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) stopDNSForwarder() {
|
||||||
|
if e.dnsForwardMgr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||||
|
log.Errorf("failed to stop DNS forward: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.unregisterDNSServices()
|
||||||
|
e.dnsForwardMgr = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) registerDNSServices() {
|
||||||
|
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||||
|
if registrar, ok := e.firewall.(interface {
|
||||||
|
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||||
|
}); ok {
|
||||||
|
registrar.RegisterNetstackService(nftypes.UDP, nbdns.ForwarderServerPort)
|
||||||
|
registrar.RegisterNetstackService(nftypes.TCP, nbdns.ForwarderServerPort)
|
||||||
|
log.Debugf("registered DNS forwarder service with netstack for UDP/TCP:%d", nbdns.ForwarderServerPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) unregisterDNSServices() {
|
||||||
|
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||||
|
if registrar, ok := e.firewall.(interface {
|
||||||
|
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||||
|
}); ok {
|
||||||
|
registrar.UnregisterNetstackService(nftypes.UDP, nbdns.ForwarderServerPort)
|
||||||
|
registrar.UnregisterNetstackService(nftypes.TCP, nbdns.ForwarderServerPort)
|
||||||
|
log.Debugf("unregistered DNS forwarder service with netstack for UDP/TCP:%d", nbdns.ForwarderServerPort)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user