mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
225 lines
6.4 KiB
Go
225 lines
6.4 KiB
Go
package dnsfwd
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"os"
|
|
"strconv"
|
|
|
|
"github.com/hashicorp/go-multierror"
|
|
log "github.com/sirupsen/logrus"
|
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
|
|
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
|
"github.com/netbirdio/netbird/client/internal/peer"
|
|
nbdns "github.com/netbirdio/netbird/dns"
|
|
"github.com/netbirdio/netbird/route"
|
|
"github.com/netbirdio/netbird/shared/management/domain"
|
|
)
|
|
|
|
const (
	// envServerPort is the name of the environment variable that NewManager
	// consults to override the port the DNS forwarder listens on.
	envServerPort = "NB_DNS_FORWARDER_PORT"
	// dnsTTL is the TTL value handed to NewDNSForwarder.
	// NOTE(review): presumably expressed in seconds — confirm against the forwarder.
	dnsTTL = 60
)
|
|
|
|
// wgIface defines the interface for WireGuard interface operations needed by the DNS forwarder.
|
|
type wgIface interface {
|
|
GetNet() *netstack.Net
|
|
Address() wgaddr.Address
|
|
}
|
|
|
|
// ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list.
|
|
type ForwarderEntry struct {
|
|
Domain domain.Domain
|
|
ResID route.ResID
|
|
Set firewall.Set
|
|
}
|
|
|
|
type Manager struct {
|
|
firewall firewall.Manager
|
|
statusRecorder *peer.Status
|
|
wgIface wgIface
|
|
serverPort uint16
|
|
|
|
fwRules []firewall.Rule
|
|
tcpRules []firewall.Rule
|
|
dnsForwarder *DNSForwarder
|
|
}
|
|
|
|
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, wgIface wgIface) *Manager {
|
|
serverPort := nbdns.ForwarderServerPort
|
|
if envPort := os.Getenv(envServerPort); envPort != "" {
|
|
if port, err := strconv.ParseUint(envPort, 10, 16); err == nil && port > 0 {
|
|
serverPort = uint16(port)
|
|
log.Infof("using custom DNS forwarder port from %s: %d", envServerPort, serverPort)
|
|
} else {
|
|
log.Warnf("invalid %s value %q, using default %d", envServerPort, envPort, nbdns.ForwarderServerPort)
|
|
}
|
|
}
|
|
|
|
return &Manager{
|
|
firewall: fw,
|
|
statusRecorder: statusRecorder,
|
|
wgIface: wgIface,
|
|
serverPort: serverPort,
|
|
}
|
|
}
|
|
|
|
func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
|
|
log.Infof("starting DNS forwarder")
|
|
if m.dnsForwarder != nil {
|
|
return nil
|
|
}
|
|
|
|
if err := m.allowDNSFirewall(); err != nil {
|
|
return err
|
|
}
|
|
|
|
localAddr := m.wgIface.Address().IP
|
|
|
|
if localAddr.IsValid() && m.firewall != nil {
|
|
if err := m.firewall.AddInboundDNAT(localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
|
log.Warnf("failed to add DNS UDP DNAT rule: %v", err)
|
|
} else {
|
|
log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", localAddr, nbdns.ForwarderClientPort, localAddr, m.serverPort)
|
|
}
|
|
|
|
if err := m.firewall.AddInboundDNAT(localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
|
log.Warnf("failed to add DNS TCP DNAT rule: %v", err)
|
|
} else {
|
|
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", localAddr, nbdns.ForwarderClientPort, localAddr, m.serverPort)
|
|
}
|
|
}
|
|
|
|
listenAddress := netip.AddrPortFrom(localAddr, m.serverPort)
|
|
m.dnsForwarder = NewDNSForwarder(listenAddress, dnsTTL, m.firewall, m.statusRecorder, m.wgIface)
|
|
|
|
go func() {
|
|
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
|
|
// todo handle close error if it is exists
|
|
log.Errorf("failed to start DNS forwarder, err: %v", err)
|
|
}
|
|
}()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *Manager) UpdateDomains(entries []*ForwarderEntry) {
|
|
if m.dnsForwarder == nil {
|
|
return
|
|
}
|
|
|
|
m.dnsForwarder.UpdateDomains(entries)
|
|
}
|
|
|
|
func (m *Manager) Stop(ctx context.Context) error {
|
|
if m.dnsForwarder == nil {
|
|
return nil
|
|
}
|
|
|
|
var mErr *multierror.Error
|
|
|
|
localAddr := m.wgIface.Address().IP
|
|
if localAddr.IsValid() && m.firewall != nil {
|
|
if err := m.firewall.RemoveInboundDNAT(localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
|
mErr = multierror.Append(mErr, fmt.Errorf("remove DNS UDP DNAT rule: %w", err))
|
|
}
|
|
|
|
if err := m.firewall.RemoveInboundDNAT(localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
|
mErr = multierror.Append(mErr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err))
|
|
}
|
|
}
|
|
|
|
m.unregisterNetstackServices()
|
|
|
|
if err := m.dropDNSFirewall(); err != nil {
|
|
mErr = multierror.Append(mErr, err)
|
|
}
|
|
|
|
if err := m.dnsForwarder.Close(ctx); err != nil {
|
|
mErr = multierror.Append(mErr, err)
|
|
}
|
|
|
|
m.dnsForwarder = nil
|
|
return nberrors.FormatErrorOrNil(mErr)
|
|
}
|
|
|
|
func (m *Manager) allowDNSFirewall() error {
|
|
dport := &firewall.Port{
|
|
IsRange: false,
|
|
Values: []uint16{m.serverPort},
|
|
}
|
|
|
|
if m.firewall == nil {
|
|
return nil
|
|
}
|
|
|
|
dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
|
|
if err != nil {
|
|
return fmt.Errorf("add udp firewall rule: %w", err)
|
|
}
|
|
|
|
tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "")
|
|
if err != nil {
|
|
return fmt.Errorf("add tcp firewall rule: %w", err)
|
|
}
|
|
|
|
if err := m.firewall.Flush(); err != nil {
|
|
return fmt.Errorf("flush: %w", err)
|
|
}
|
|
|
|
m.fwRules = dnsRules
|
|
m.tcpRules = tcpRules
|
|
|
|
m.registerNetstackServices()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *Manager) registerNetstackServices() {
|
|
if netstackNet := m.wgIface.GetNet(); netstackNet != nil {
|
|
if registrar, ok := m.firewall.(interface {
|
|
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
|
}); ok {
|
|
registrar.RegisterNetstackService(nftypes.TCP, m.serverPort)
|
|
registrar.RegisterNetstackService(nftypes.UDP, m.serverPort)
|
|
log.Debugf("registered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (m *Manager) unregisterNetstackServices() {
|
|
if netstackNet := m.wgIface.GetNet(); netstackNet != nil {
|
|
if registrar, ok := m.firewall.(interface {
|
|
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
|
|
}); ok {
|
|
registrar.UnregisterNetstackService(nftypes.TCP, m.serverPort)
|
|
registrar.UnregisterNetstackService(nftypes.UDP, m.serverPort)
|
|
log.Debugf("unregistered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (m *Manager) dropDNSFirewall() error {
|
|
var mErr *multierror.Error
|
|
for _, rule := range m.fwRules {
|
|
if err := m.firewall.DeletePeerRule(rule); err != nil {
|
|
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
|
|
}
|
|
}
|
|
for _, rule := range m.tcpRules {
|
|
if err := m.firewall.DeletePeerRule(rule); err != nil {
|
|
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
|
|
}
|
|
}
|
|
|
|
m.fwRules = nil
|
|
m.tcpRules = nil
|
|
return nberrors.FormatErrorOrNil(mErr)
|
|
}
|