mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
Feature/dns-server (#537)
Adding DNS server for client Updated the API with new fields Added custom zone object for peer's DNS resolution
This commit is contained in:
56
client/internal/dns/local.go
Normal file
56
client/internal/dns/local.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type localResolver struct {
|
||||
registeredMap registrationMap
|
||||
records sync.Map
|
||||
}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
log.Tracef("received question: %#v\n", r.Question[0])
|
||||
response := d.lookupRecord(r)
|
||||
if response == nil {
|
||||
log.Debugf("got empty response for question: %#v\n", r.Question[0])
|
||||
return
|
||||
}
|
||||
|
||||
replyMessage := &dns.Msg{}
|
||||
replyMessage.SetReply(r)
|
||||
replyMessage.Answer = append(replyMessage.Answer, response)
|
||||
|
||||
err := w.WriteMsg(replyMessage)
|
||||
if err != nil {
|
||||
log.Debugf("got an error while writing the local resolver response, error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *localResolver) lookupRecord(r *dns.Msg) dns.RR {
|
||||
record, found := d.records.Load(r.Question[0].Name)
|
||||
if !found {
|
||||
return nil
|
||||
}
|
||||
|
||||
return record.(dns.RR)
|
||||
}
|
||||
|
||||
func (d *localResolver) registerRecord(record nbdns.SimpleRecord) error {
|
||||
fullRecord, err := dns.NewRR(record.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.records.Store(fullRecord.Header().Name, fullRecord)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *localResolver) deleteRecord(recordKey string) {
|
||||
d.records.Delete(dns.Fqdn(recordKey))
|
||||
}
|
||||
86
client/internal/dns/local_test.go
Normal file
86
client/internal/dns/local_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLocalResolver_ServeDNS(t *testing.T) {
|
||||
recordA := nbdns.SimpleRecord{
|
||||
Name: "peera.netbird.cloud.",
|
||||
Type: 1,
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "1.2.3.4",
|
||||
}
|
||||
|
||||
recordCNAME := nbdns.SimpleRecord{
|
||||
Name: "peerb.netbird.cloud.",
|
||||
Type: 5,
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "www.netbird.io",
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputRecord nbdns.SimpleRecord
|
||||
inputMSG *dns.Msg
|
||||
responseShouldBeNil bool
|
||||
}{
|
||||
{
|
||||
name: "Should Resolve A Record",
|
||||
inputRecord: recordA,
|
||||
inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA),
|
||||
},
|
||||
{
|
||||
name: "Should Resolve CNAME Record",
|
||||
inputRecord: recordCNAME,
|
||||
inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME),
|
||||
},
|
||||
{
|
||||
name: "Should Not Write When Not Found A Record",
|
||||
inputRecord: recordA,
|
||||
inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
|
||||
responseShouldBeNil: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
resolver := &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
}
|
||||
_ = resolver.registerRecord(testCase.inputRecord)
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &mockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.ServeDNS(responseWriter, testCase.inputMSG)
|
||||
|
||||
if responseMSG == nil {
|
||||
if testCase.responseShouldBeNil {
|
||||
return
|
||||
}
|
||||
t.Fatalf("should write a response message")
|
||||
}
|
||||
|
||||
answerString := responseMSG.Answer[0].String()
|
||||
if !strings.Contains(answerString, testCase.inputRecord.Name) {
|
||||
t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString)
|
||||
}
|
||||
if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) {
|
||||
t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString)
|
||||
}
|
||||
if !strings.Contains(answerString, testCase.inputRecord.RData) {
|
||||
t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
25
client/internal/dns/mock_test.go
Normal file
25
client/internal/dns/mock_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
"net"
|
||||
)
|
||||
|
||||
type mockResponseWriter struct {
|
||||
WriteMsgFunc func(m *dns.Msg) error
|
||||
}
|
||||
|
||||
func (rw *mockResponseWriter) WriteMsg(m *dns.Msg) error {
|
||||
if rw.WriteMsgFunc != nil {
|
||||
return rw.WriteMsgFunc(m)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rw *mockResponseWriter) LocalAddr() net.Addr { return nil }
|
||||
func (rw *mockResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||
func (rw *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
||||
func (rw *mockResponseWriter) Close() error { return nil }
|
||||
func (rw *mockResponseWriter) TsigStatus() error { return nil }
|
||||
func (rw *mockResponseWriter) TsigTimersOnly(bool) {}
|
||||
func (rw *mockResponseWriter) Hijack() {}
|
||||
270
client/internal/dns/server.go
Normal file
270
client/internal/dns/server.go
Normal file
@@ -0,0 +1,270 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/miekg/dns"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
port = 5053
|
||||
defaultIP = "0.0.0.0"
|
||||
)
|
||||
|
||||
// Server dns server object
|
||||
type Server struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
mux sync.Mutex
|
||||
server *dns.Server
|
||||
dnsMux *dns.ServeMux
|
||||
dnsMuxMap registrationMap
|
||||
localResolver *localResolver
|
||||
updateSerial uint64
|
||||
listenerIsRunning bool
|
||||
}
|
||||
|
||||
type registrationMap map[string]struct{}
|
||||
|
||||
type muxUpdate struct {
|
||||
domain string
|
||||
handler dns.Handler
|
||||
}
|
||||
|
||||
// NewServer returns a new dns server
|
||||
func NewServer(ctx context.Context) *Server {
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
dnsServer := &dns.Server{
|
||||
Addr: fmt.Sprintf("%s:%d", defaultIP, port),
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
UDPSize: 65535,
|
||||
}
|
||||
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
|
||||
return &Server{
|
||||
ctx: ctx,
|
||||
stop: stop,
|
||||
server: dnsServer,
|
||||
dnsMux: mux,
|
||||
dnsMuxMap: make(registrationMap),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Start runs the listener in a go routine
|
||||
func (s *Server) Start() {
|
||||
log.Debugf("starting dns on %s:%d", defaultIP, port)
|
||||
go func() {
|
||||
s.setListenerStatus(true)
|
||||
defer s.setListenerStatus(false)
|
||||
err := s.server.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Errorf("dns server returned an error: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Server) setListenerStatus(running bool) {
|
||||
s.listenerIsRunning = running
|
||||
}
|
||||
|
||||
// Stop stops the server
|
||||
func (s *Server) Stop() {
|
||||
s.stop()
|
||||
|
||||
err := s.stopListener()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) stopListener() error {
|
||||
if !s.listenerIsRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := s.server.ShutdownContext(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateDNSServer processes an update received from the management service
|
||||
func (s *Server) UpdateDNSServer(serial uint64, update nbdns.Update) error {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
log.Infof("not updating DNS server as context is closed")
|
||||
return s.ctx.Err()
|
||||
default:
|
||||
if serial < s.updateSerial {
|
||||
return fmt.Errorf("not applying dns update, error: "+
|
||||
"network update is %d behind the last applied update", s.updateSerial-serial)
|
||||
}
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
// is the service should be disabled, we stop the listener
|
||||
// and proceed with a regular update to clean up the handlers and records
|
||||
if !update.ServiceEnable {
|
||||
err := s.stopListener()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
} else if !s.listenerIsRunning {
|
||||
s.Start()
|
||||
}
|
||||
|
||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
}
|
||||
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
}
|
||||
|
||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
||||
|
||||
s.updateMux(muxUpdates)
|
||||
s.updateLocalResolver(localRecords)
|
||||
|
||||
s.updateSerial = serial
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
|
||||
var muxUpdates []muxUpdate
|
||||
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
||||
|
||||
for _, customZone := range customZones {
|
||||
|
||||
if len(customZone.Records) == 0 {
|
||||
return nil, nil, fmt.Errorf("received an empty list of records")
|
||||
}
|
||||
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: customZone.Domain,
|
||||
handler: s.localResolver,
|
||||
})
|
||||
|
||||
for _, record := range customZone.Records {
|
||||
localRecords[record.Name] = record
|
||||
}
|
||||
}
|
||||
return muxUpdates, localRecords, nil
|
||||
}
|
||||
|
||||
func (s *Server) buildUpstreamHandlerUpdate(nameServerGroups []nbdns.NameServerGroup) ([]muxUpdate, error) {
|
||||
var muxUpdates []muxUpdate
|
||||
for _, nsGroup := range nameServerGroups {
|
||||
if len(nsGroup.NameServers) == 0 {
|
||||
return nil, fmt.Errorf("received a nameserver group with empty nameserver list")
|
||||
}
|
||||
handler := &upstreamResolver{
|
||||
parentCTX: s.ctx,
|
||||
upstreamClient: &dns.Client{},
|
||||
upstreamTimeout: defaultUpstreamTimeout,
|
||||
}
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
if ns.NSType != nbdns.UDPNameServerType {
|
||||
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
|
||||
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
|
||||
continue
|
||||
}
|
||||
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns))
|
||||
}
|
||||
|
||||
if len(handler.upstreamServers) == 0 {
|
||||
log.Errorf("received a nameserver group with an invalid nameserver list")
|
||||
continue
|
||||
}
|
||||
|
||||
if nsGroup.Primary {
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: nbdns.RootZone,
|
||||
handler: handler,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if len(nsGroup.Domains) == 0 {
|
||||
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
|
||||
}
|
||||
|
||||
for _, domain := range nsGroup.Domains {
|
||||
if domain == "" {
|
||||
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
||||
}
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: domain,
|
||||
handler: handler,
|
||||
})
|
||||
}
|
||||
}
|
||||
return muxUpdates, nil
|
||||
}
|
||||
|
||||
func (s *Server) updateMux(muxUpdates []muxUpdate) {
|
||||
muxUpdateMap := make(registrationMap)
|
||||
|
||||
for _, update := range muxUpdates {
|
||||
s.registerMux(update.domain, update.handler)
|
||||
muxUpdateMap[update.domain] = struct{}{}
|
||||
}
|
||||
|
||||
for key := range s.dnsMuxMap {
|
||||
_, found := muxUpdateMap[key]
|
||||
if !found {
|
||||
s.deregisterMux(key)
|
||||
}
|
||||
}
|
||||
|
||||
s.dnsMuxMap = muxUpdateMap
|
||||
}
|
||||
|
||||
func (s *Server) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
||||
for key := range s.localResolver.registeredMap {
|
||||
_, found := update[key]
|
||||
if !found {
|
||||
s.localResolver.deleteRecord(key)
|
||||
}
|
||||
}
|
||||
|
||||
updatedMap := make(registrationMap)
|
||||
for key, record := range update {
|
||||
err := s.localResolver.registerRecord(record)
|
||||
if err != nil {
|
||||
log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err)
|
||||
}
|
||||
updatedMap[key] = struct{}{}
|
||||
}
|
||||
|
||||
s.localResolver.registeredMap = updatedMap
|
||||
}
|
||||
|
||||
func getNSHostPort(ns nbdns.NameServer) string {
|
||||
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
||||
}
|
||||
|
||||
func (s *Server) registerMux(pattern string, handler dns.Handler) {
|
||||
s.dnsMux.Handle(pattern, handler)
|
||||
}
|
||||
|
||||
func (s *Server) deregisterMux(pattern string) {
|
||||
s.dnsMux.HandleRemove(pattern)
|
||||
}
|
||||
285
client/internal/dns/server_test.go
Normal file
285
client/internal/dns/server_test.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var zoneRecords = []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "peera.netbird.cloud",
|
||||
Type: 1,
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "1.2.3.4",
|
||||
},
|
||||
}
|
||||
|
||||
func TestUpdateDNSServer(t *testing.T) {
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.4.4"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
initUpstreamMap registrationMap
|
||||
initLocalMap registrationMap
|
||||
initSerial uint64
|
||||
inputSerial uint64
|
||||
inputUpdate nbdns.Update
|
||||
shouldFail bool
|
||||
expectedUpstreamMap registrationMap
|
||||
expectedLocalMap registrationMap
|
||||
}{
|
||||
{
|
||||
name: "Initial Update Should Succeed",
|
||||
initLocalMap: make(registrationMap),
|
||||
initUpstreamMap: make(registrationMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Update{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registrationMap{"netbird.io": struct{}{}, "netbird.cloud": struct{}{}, nbdns.RootZone: struct{}{}},
|
||||
expectedLocalMap: registrationMap{zoneRecords[0].Name: struct{}{}},
|
||||
},
|
||||
{
|
||||
name: "New Update Should Succeed",
|
||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||
initUpstreamMap: registrationMap{zoneRecords[0].Name: struct{}{}},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Update{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registrationMap{"netbird.io": struct{}{}, "netbird.cloud": struct{}{}},
|
||||
expectedLocalMap: registrationMap{zoneRecords[0].Name: struct{}{}},
|
||||
},
|
||||
{
|
||||
name: "Smaller Update Serial Should Be Skipped",
|
||||
initLocalMap: make(registrationMap),
|
||||
initUpstreamMap: make(registrationMap),
|
||||
initSerial: 2,
|
||||
inputSerial: 1,
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||
initLocalMap: make(registrationMap),
|
||||
initUpstreamMap: make(registrationMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Update{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid NS Group Nameservers list Should Fail",
|
||||
initLocalMap: make(registrationMap),
|
||||
initUpstreamMap: make(registrationMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Update{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Custom Zone Records list Should Fail",
|
||||
initLocalMap: make(registrationMap),
|
||||
initUpstreamMap: make(registrationMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Update{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
},
|
||||
},
|
||||
NameServerGroups: []nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Empty Update Should Succeed and Clean Maps",
|
||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||
initUpstreamMap: registrationMap{zoneRecords[0].Name: struct{}{}},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Update{ServiceEnable: true},
|
||||
expectedUpstreamMap: make(registrationMap),
|
||||
expectedLocalMap: make(registrationMap),
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dnsServer := NewServer(ctx)
|
||||
|
||||
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
||||
dnsServer.localResolver.registeredMap = testCase.initLocalMap
|
||||
dnsServer.updateSerial = testCase.initSerial
|
||||
dnsServer.listenerIsRunning = true
|
||||
|
||||
err := dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||
if err != nil {
|
||||
if testCase.shouldFail {
|
||||
return
|
||||
}
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
}
|
||||
|
||||
if len(dnsServer.dnsMuxMap) != len(testCase.expectedUpstreamMap) {
|
||||
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxMap))
|
||||
}
|
||||
|
||||
for key := range testCase.expectedUpstreamMap {
|
||||
_, found := dnsServer.dnsMuxMap[key]
|
||||
if !found {
|
||||
t.Fatalf("update upstream failed, key %s was not found in the dnsMuxMap: %#v", key, dnsServer.dnsMuxMap)
|
||||
}
|
||||
}
|
||||
|
||||
if len(dnsServer.localResolver.registeredMap) != len(testCase.expectedLocalMap) {
|
||||
t.Fatalf("update local failed, registered map size is different than expected, want %d, got %d", len(testCase.expectedLocalMap), len(dnsServer.localResolver.registeredMap))
|
||||
}
|
||||
|
||||
for key := range testCase.expectedLocalMap {
|
||||
_, found := dnsServer.localResolver.registeredMap[key]
|
||||
if !found {
|
||||
t.Fatalf("update local failed, key %s was not found in the localResolver.registeredMap: %#v", key, dnsServer.localResolver.registeredMap)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSServerStartStop(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dnsServer := NewServer(ctx)
|
||||
if runtime.GOOS == "windows" && os.Getenv("CI") == "true" {
|
||||
// todo review why this test is not working only on github actions workflows
|
||||
t.Skip("skipping test in Windows CI workflows.")
|
||||
}
|
||||
|
||||
dnsServer.Start()
|
||||
|
||||
err := dnsServer.localResolver.registerRecord(zoneRecords[0])
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
dnsServer.dnsMux.Handle("netbird.cloud", dnsServer.localResolver)
|
||||
|
||||
resolver := &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
d := net.Dialer{
|
||||
Timeout: time.Second * 5,
|
||||
}
|
||||
addr := fmt.Sprintf("127.0.0.1:%d", port)
|
||||
conn, err := d.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
// retry test before exit, for slower systems
|
||||
return d.DialContext(ctx, network, addr)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
},
|
||||
}
|
||||
|
||||
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to the server, error: %v", err)
|
||||
}
|
||||
|
||||
t.Log(ips)
|
||||
|
||||
if ips[0] != zoneRecords[0].RData {
|
||||
t.Fatalf("got a different IP from the server: want %s, got %s", zoneRecords[0].RData, ips[0])
|
||||
}
|
||||
|
||||
dnsServer.Stop()
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*1)
|
||||
defer cancel()
|
||||
_, err = resolver.LookupHost(ctx, zoneRecords[0].Name)
|
||||
if err == nil {
|
||||
t.Fatalf("we should encounter an error when querying a stopped server")
|
||||
}
|
||||
}
|
||||
67
client/internal/dns/upstream.go
Normal file
67
client/internal/dns/upstream.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
const defaultUpstreamTimeout = 15 * time.Second
|
||||
|
||||
type upstreamResolver struct {
|
||||
parentCTX context.Context
|
||||
upstreamClient *dns.Client
|
||||
upstreamServers []string
|
||||
upstreamTimeout time.Duration
|
||||
}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
log.Tracef("received an upstream question: %#v", r.Question[0])
|
||||
|
||||
select {
|
||||
case <-u.parentCTX.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
for _, upstream := range u.upstreamServers {
|
||||
ctx, cancel := context.WithTimeout(u.parentCTX, u.upstreamTimeout)
|
||||
rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream)
|
||||
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
if err == context.DeadlineExceeded || isTimeout(err) {
|
||||
log.Warnf("got an error while connecting to upstream %s, error: %v", upstream, err)
|
||||
continue
|
||||
}
|
||||
log.Errorf("got an error while querying the upstream %s, error: %v", upstream, err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Tracef("took %s to query the upstream %s", t, upstream)
|
||||
|
||||
err = w.WriteMsg(rm)
|
||||
if err != nil {
|
||||
log.Errorf("got an error while writing the upstream resolver response, error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
log.Errorf("all queries to the upstream nameservers failed with timeout")
|
||||
}
|
||||
|
||||
// isTimeout returns true if the given error is a network timeout error.
|
||||
//
|
||||
// Copied from k8s.io/apimachinery/pkg/util/net.IsTimeout
|
||||
func isTimeout(err error) bool {
|
||||
var neterr net.Error
|
||||
if errors.As(err, &neterr) {
|
||||
return neterr != nil && neterr.Timeout()
|
||||
}
|
||||
return false
|
||||
}
|
||||
110
client/internal/dns/upstream_test.go
Normal file
110
client/internal/dns/upstream_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/miekg/dns"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputMSG *dns.Msg
|
||||
responseShouldBeNil bool
|
||||
InputServers []string
|
||||
timeout time.Duration
|
||||
cancelCTX bool
|
||||
expectedAnswer string
|
||||
}{
|
||||
{
|
||||
name: "Should Resolve A Record",
|
||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||
InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||
timeout: defaultUpstreamTimeout,
|
||||
expectedAnswer: "1.1.1.1",
|
||||
},
|
||||
{
|
||||
name: "Should Resolve If First Upstream Times Out",
|
||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
|
||||
timeout: 2 * time.Second,
|
||||
expectedAnswer: "1.1.1.1",
|
||||
},
|
||||
{
|
||||
name: "Should Not Resolve If Can't Connect To Both Servers",
|
||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||
InputServers: []string{"8.0.0.0:53", "8.0.0.1:53"},
|
||||
timeout: 200 * time.Millisecond,
|
||||
responseShouldBeNil: true,
|
||||
},
|
||||
{
|
||||
name: "Should Not Resolve If Parent Context Is Canceled",
|
||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
|
||||
cancelCTX: true,
|
||||
timeout: defaultUpstreamTimeout,
|
||||
responseShouldBeNil: true,
|
||||
},
|
||||
//{
|
||||
// name: "Should Resolve CNAME Record",
|
||||
// inputMSG: new(dns.Msg).SetQuestion("one.one.one.one", dns.TypeCNAME),
|
||||
//},
|
||||
//{
|
||||
// name: "Should Not Write When Not Found A Record",
|
||||
// inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
|
||||
// responseShouldBeNil: true,
|
||||
//},
|
||||
}
|
||||
// should resolve if first upstream times out
|
||||
// should not write when both fails
|
||||
// should not resolve if parent context is canceled
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
resolver := &upstreamResolver{
|
||||
parentCTX: ctx,
|
||||
upstreamClient: &dns.Client{},
|
||||
upstreamServers: testCase.InputServers,
|
||||
upstreamTimeout: testCase.timeout,
|
||||
}
|
||||
if testCase.cancelCTX {
|
||||
cancel()
|
||||
} else {
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &mockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.ServeDNS(responseWriter, testCase.inputMSG)
|
||||
|
||||
if responseMSG == nil {
|
||||
if testCase.responseShouldBeNil {
|
||||
return
|
||||
}
|
||||
t.Fatalf("should write a response message")
|
||||
}
|
||||
|
||||
foundAnswer := false
|
||||
for _, answer := range responseMSG.Answer {
|
||||
if strings.Contains(answer.String(), testCase.expectedAnswer) {
|
||||
foundAnswer = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !foundAnswer {
|
||||
t.Errorf("couldn't find the required answer, %s, in the dns response", testCase.expectedAnswer)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user