[client, management] Add new network concept (#3047)

---------

Co-authored-by: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com>
Co-authored-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
This commit is contained in:
Viktor Liu
2024-12-20 11:30:28 +01:00
committed by GitHub
parent 37ad370344
commit ddc365f7a0
155 changed files with 13909 additions and 4993 deletions

View File

@@ -0,0 +1,157 @@
package dnsfwd
import (
"context"
"errors"
"net"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
)
const errResolveFailed = "failed to resolve query for domain=%s: %v"
type DNSForwarder struct {
listenAddress string
ttl uint32
domains []string
dnsServer *dns.Server
mux *dns.ServeMux
}
func NewDNSForwarder(listenAddress string, ttl uint32) *DNSForwarder {
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
return &DNSForwarder{
listenAddress: listenAddress,
ttl: ttl,
}
}
func (f *DNSForwarder) Listen(domains []string) error {
log.Infof("listen DNS forwarder on address=%s", f.listenAddress)
mux := dns.NewServeMux()
dnsServer := &dns.Server{
Addr: f.listenAddress,
Net: "udp",
Handler: mux,
}
f.dnsServer = dnsServer
f.mux = mux
f.UpdateDomains(domains)
return dnsServer.ListenAndServe()
}
func (f *DNSForwarder) UpdateDomains(domains []string) {
log.Debugf("Updating domains from %v to %v", f.domains, domains)
for _, d := range f.domains {
f.mux.HandleRemove(d)
}
newDomains := filterDomains(domains)
for _, d := range newDomains {
f.mux.HandleFunc(d, f.handleDNSQuery)
}
f.domains = newDomains
}
func (f *DNSForwarder) Close(ctx context.Context) error {
if f.dnsServer == nil {
return nil
}
return f.dnsServer.ShutdownContext(ctx)
}
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
if len(query.Question) == 0 {
return
}
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
query.Question[0].Name, query.Question[0].Qtype, query.Question[0].Qclass)
question := query.Question[0]
domain := question.Name
resp := query.SetReply(query)
ips, err := net.LookupIP(domain)
if err != nil {
var dnsErr *net.DNSError
switch {
case errors.As(err, &dnsErr):
resp.Rcode = dns.RcodeServerFailure
if dnsErr.IsNotFound {
// Pass through NXDOMAIN
resp.Rcode = dns.RcodeNameError
}
if dnsErr.Server != "" {
log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err)
} else {
log.Warnf(errResolveFailed, domain, err)
}
default:
resp.Rcode = dns.RcodeServerFailure
log.Warnf(errResolveFailed, domain, err)
}
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write failure DNS response: %v", err)
}
return
}
for _, ip := range ips {
var respRecord dns.RR
if ip.To4() == nil {
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
rr := dns.AAAA{
AAAA: ip,
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: f.ttl,
},
}
respRecord = &rr
} else {
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
rr := dns.A{
A: ip,
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: f.ttl,
},
}
respRecord = &rr
}
resp.Answer = append(resp.Answer, respRecord)
}
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
}
}
// filterDomains returns a list of normalized domains
func filterDomains(domains []string) []string {
newDomains := make([]string, 0, len(domains))
for _, d := range domains {
if d == "" {
log.Warn("empty domain in DNS forwarder")
continue
}
newDomains = append(newDomains, nbdns.NormalizeZone(d))
}
return newDomains
}

View File

@@ -0,0 +1,106 @@
package dnsfwd
import (
"context"
"fmt"
"net"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
ListenPort = 5353
dnsTTL = 60 //seconds
)
type Manager struct {
firewall firewall.Manager
fwRules []firewall.Rule
dnsForwarder *DNSForwarder
}
func NewManager(fw firewall.Manager) *Manager {
return &Manager{
firewall: fw,
}
}
func (m *Manager) Start(domains []string) error {
log.Infof("starting DNS forwarder")
if m.dnsForwarder != nil {
return nil
}
if err := m.allowDNSFirewall(); err != nil {
return err
}
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL)
go func() {
if err := m.dnsForwarder.Listen(domains); err != nil {
// todo handle close error if it is exists
log.Errorf("failed to start DNS forwarder, err: %v", err)
}
}()
return nil
}
func (m *Manager) UpdateDomains(domains []string) {
if m.dnsForwarder == nil {
return
}
m.dnsForwarder.UpdateDomains(domains)
}
func (m *Manager) Stop(ctx context.Context) error {
if m.dnsForwarder == nil {
return nil
}
var mErr *multierror.Error
if err := m.dropDNSFirewall(); err != nil {
mErr = multierror.Append(mErr, err)
}
if err := m.dnsForwarder.Close(ctx); err != nil {
mErr = multierror.Append(mErr, err)
}
m.dnsForwarder = nil
return nberrors.FormatErrorOrNil(mErr)
}
func (h *Manager) allowDNSFirewall() error {
dport := &firewall.Port{
IsRange: false,
Values: []int{ListenPort},
}
dnsRules, err := h.firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err)
return err
}
h.fwRules = dnsRules
return nil
}
func (h *Manager) dropDNSFirewall() error {
var mErr *multierror.Error
for _, rule := range h.fwRules {
if err := h.firewall.DeletePeerRule(rule); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
}
}
h.fwRules = nil
return nberrors.FormatErrorOrNil(mErr)
}