Feature/dns protocol (#543)

Added DNS update protocol message

Added sync to clients

Update nameserver API with new fields

Added default NS groups

Added new dns-name flag for the management service append to peer DNS label
This commit is contained in:
Maycon Santos
2022-11-07 15:38:21 +01:00
committed by GitHub
parent d0c6d88971
commit 270f0e4ce8
29 changed files with 1316 additions and 242 deletions

View File

@@ -68,7 +68,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
}
peersUpdateManager := mgmt.NewPeersUpdateManager()
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "")
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "")
if err != nil {
t.Fatal(err)
}

View File

@@ -0,0 +1,35 @@
package dns
import (
"fmt"
nbdns "github.com/netbirdio/netbird/dns"
)
// MockServer is the mock instance of a dns server
type MockServer struct {
StartFunc func()
StopFunc func()
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
}
// Start mock implementation of Start from Server interface
func (m *MockServer) Start() {
if m.StartFunc != nil {
m.StartFunc()
}
}
// Stop mock implementation of Stop from Server interface
func (m *MockServer) Stop() {
if m.StopFunc != nil {
m.StopFunc()
}
}
// UpdateDNSServer mock implementation of UpdateDNSServer from Server interface
func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
if m.UpdateDNSServerFunc != nil {
return m.UpdateDNSServerFunc(serial, update)
}
return fmt.Errorf("method UpdateDNSServer is not implemented")
}

View File

@@ -15,8 +15,15 @@ const (
defaultIP = "0.0.0.0"
)
// Server dns server object
type Server struct {
// Server is a dns server interface
type Server interface {
Start()
Stop()
UpdateDNSServer(serial uint64, update nbdns.Config) error
}
// DefaultServer dns server object
type DefaultServer struct {
ctx context.Context
stop context.CancelFunc
mux sync.Mutex
@@ -35,8 +42,8 @@ type muxUpdate struct {
handler dns.Handler
}
// NewServer returns a new dns server
func NewServer(ctx context.Context) *Server {
// NewDefaultServer returns a new dns server
func NewDefaultServer(ctx context.Context) *DefaultServer {
mux := dns.NewServeMux()
dnsServer := &dns.Server{
@@ -48,7 +55,7 @@ func NewServer(ctx context.Context) *Server {
ctx, stop := context.WithCancel(ctx)
return &Server{
return &DefaultServer{
ctx: ctx,
stop: stop,
server: dnsServer,
@@ -61,7 +68,7 @@ func NewServer(ctx context.Context) *Server {
}
// Start runs the listener in a go routine
func (s *Server) Start() {
func (s *DefaultServer) Start() {
log.Debugf("starting dns on %s:%d", defaultIP, port)
go func() {
s.setListenerStatus(true)
@@ -73,12 +80,12 @@ func (s *Server) Start() {
}()
}
func (s *Server) setListenerStatus(running bool) {
func (s *DefaultServer) setListenerStatus(running bool) {
s.listenerIsRunning = running
}
// Stop stops the server
func (s *Server) Stop() {
func (s *DefaultServer) Stop() {
s.stop()
err := s.stopListener()
@@ -87,7 +94,7 @@ func (s *Server) Stop() {
}
}
func (s *Server) stopListener() error {
func (s *DefaultServer) stopListener() error {
if !s.listenerIsRunning {
return nil
}
@@ -103,7 +110,7 @@ func (s *Server) stopListener() error {
}
// UpdateDNSServer processes an update received from the management service
func (s *Server) UpdateDNSServer(serial uint64, update nbdns.Update) error {
func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
select {
case <-s.ctx.Done():
log.Infof("not updating DNS server as context is closed")
@@ -147,7 +154,7 @@ func (s *Server) UpdateDNSServer(serial uint64, update nbdns.Update) error {
}
}
func (s *Server) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
var muxUpdates []muxUpdate
localRecords := make(map[string]nbdns.SimpleRecord, 0)
@@ -169,7 +176,7 @@ func (s *Server) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxU
return muxUpdates, localRecords, nil
}
func (s *Server) buildUpstreamHandlerUpdate(nameServerGroups []nbdns.NameServerGroup) ([]muxUpdate, error) {
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
var muxUpdates []muxUpdate
for _, nsGroup := range nameServerGroups {
if len(nsGroup.NameServers) == 0 {
@@ -219,7 +226,7 @@ func (s *Server) buildUpstreamHandlerUpdate(nameServerGroups []nbdns.NameServerG
return muxUpdates, nil
}
func (s *Server) updateMux(muxUpdates []muxUpdate) {
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
muxUpdateMap := make(registrationMap)
for _, update := range muxUpdates {
@@ -237,7 +244,7 @@ func (s *Server) updateMux(muxUpdates []muxUpdate) {
s.dnsMuxMap = muxUpdateMap
}
func (s *Server) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
for key := range s.localResolver.registeredMap {
_, found := update[key]
if !found {
@@ -261,10 +268,10 @@ func getNSHostPort(ns nbdns.NameServer) string {
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
}
func (s *Server) registerMux(pattern string, handler dns.Handler) {
func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler)
}
func (s *Server) deregisterMux(pattern string) {
func (s *DefaultServer) deregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}

View File

@@ -43,18 +43,18 @@ func TestUpdateDNSServer(t *testing.T) {
initLocalMap registrationMap
initSerial uint64
inputSerial uint64
inputUpdate nbdns.Update
inputUpdate nbdns.Config
shouldFail bool
expectedUpstreamMap registrationMap
expectedLocalMap registrationMap
}{
{
name: "Initial Update Should Succeed",
name: "Initial Config Should Succeed",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Update{
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
@@ -62,7 +62,7 @@ func TestUpdateDNSServer(t *testing.T) {
Records: zoneRecords,
},
},
NameServerGroups: []nbdns.NameServerGroup{
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
@@ -77,12 +77,12 @@ func TestUpdateDNSServer(t *testing.T) {
expectedLocalMap: registrationMap{zoneRecords[0].Name: struct{}{}},
},
{
name: "New Update Should Succeed",
name: "New Config Should Succeed",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
initUpstreamMap: registrationMap{zoneRecords[0].Name: struct{}{}},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Update{
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
@@ -90,7 +90,7 @@ func TestUpdateDNSServer(t *testing.T) {
Records: zoneRecords,
},
},
NameServerGroups: []nbdns.NameServerGroup{
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
@@ -101,7 +101,7 @@ func TestUpdateDNSServer(t *testing.T) {
expectedLocalMap: registrationMap{zoneRecords[0].Name: struct{}{}},
},
{
name: "Smaller Update Serial Should Be Skipped",
name: "Smaller Config Serial Should Be Skipped",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap),
initSerial: 2,
@@ -114,7 +114,7 @@ func TestUpdateDNSServer(t *testing.T) {
initUpstreamMap: make(registrationMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Update{
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
@@ -122,7 +122,7 @@ func TestUpdateDNSServer(t *testing.T) {
Records: zoneRecords,
},
},
NameServerGroups: []nbdns.NameServerGroup{
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
},
@@ -136,7 +136,7 @@ func TestUpdateDNSServer(t *testing.T) {
initUpstreamMap: make(registrationMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Update{
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
@@ -144,7 +144,7 @@ func TestUpdateDNSServer(t *testing.T) {
Records: zoneRecords,
},
},
NameServerGroups: []nbdns.NameServerGroup{
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
},
@@ -158,14 +158,14 @@ func TestUpdateDNSServer(t *testing.T) {
initUpstreamMap: make(registrationMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Update{
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
},
},
NameServerGroups: []nbdns.NameServerGroup{
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
Primary: true,
@@ -175,12 +175,12 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true,
},
{
name: "Empty Update Should Succeed and Clean Maps",
name: "Empty Config Should Succeed and Clean Maps",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
initUpstreamMap: registrationMap{zoneRecords[0].Name: struct{}{}},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Update{ServiceEnable: true},
inputUpdate: nbdns.Config{ServiceEnable: true},
expectedUpstreamMap: make(registrationMap),
expectedLocalMap: make(registrationMap),
},
@@ -189,7 +189,7 @@ func TestUpdateDNSServer(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
ctx := context.Background()
dnsServer := NewServer(ctx)
dnsServer := NewDefaultServer(ctx)
dnsServer.dnsMuxMap = testCase.initUpstreamMap
dnsServer.localResolver.registeredMap = testCase.initLocalMap
@@ -231,7 +231,7 @@ func TestUpdateDNSServer(t *testing.T) {
func TestDNSServerStartStop(t *testing.T) {
ctx := context.Background()
dnsServer := NewServer(ctx)
dnsServer := NewDefaultServer(ctx)
if runtime.GOOS == "windows" && os.Getenv("CI") == "true" {
// todo review why this test is not working only on github actions workflows
t.Skip("skipping test in Windows CI workflows.")

View File

@@ -7,9 +7,11 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager"
nbssh "github.com/netbirdio/netbird/client/ssh"
nbstatus "github.com/netbirdio/netbird/client/status"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"math/rand"
"net"
"net/netip"
"reflect"
"runtime"
"strings"
@@ -105,7 +107,7 @@ type Engine struct {
routeManager routemanager.Manager
dnsServer *dns.Server
dnsServer dns.Server
}
// Peer is an instance of the Connection Peer
@@ -133,7 +135,7 @@ func NewEngine(
networkSerial: 0,
sshServerFunc: nbssh.DefaultSSHServer,
statusRecorder: statusRecorder,
dnsServer: dns.NewServer(ctx),
dnsServer: dns.NewDefaultServer(ctx),
}
}
@@ -646,6 +648,15 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
log.Errorf("failed to update routes, err: %v", err)
}
protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil {
protoDNSConfig = &mgmProto.DNSConfig{}
}
err = e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig))
if err != nil {
log.Errorf("failed to update dns server, err: %v", err)
}
e.networkSerial = serial
return nil
}
@@ -668,6 +679,48 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
return routes
}
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
dnsUpdate := nbdns.Config{
ServiceEnable: protoDNSConfig.GetServiceEnable(),
CustomZones: make([]nbdns.CustomZone, 0),
NameServerGroups: make([]*nbdns.NameServerGroup, 0),
}
for _, zone := range protoDNSConfig.GetCustomZones() {
dnsZone := nbdns.CustomZone{
Domain: zone.GetDomain(),
}
for _, record := range zone.Records {
dnsRecord := nbdns.SimpleRecord{
Name: record.GetName(),
Type: int(record.GetType()),
Class: record.GetClass(),
TTL: int(record.GetTTL()),
RData: record.GetRData(),
}
dnsZone.Records = append(dnsZone.Records, dnsRecord)
}
dnsUpdate.CustomZones = append(dnsUpdate.CustomZones, dnsZone)
}
for _, nsGroup := range protoDNSConfig.GetNameServerGroups() {
dnsNSGroup := &nbdns.NameServerGroup{
Primary: nsGroup.GetPrimary(),
Domains: nsGroup.GetDomains(),
}
for _, ns := range nsGroup.GetNameServers() {
dnsNS := nbdns.NameServer{
IP: netip.MustParseAddr(ns.GetIP()),
NSType: nbdns.NameServerType(ns.GetNSType()),
Port: int(ns.GetPort()),
}
dnsNSGroup.NameServers = append(dnsNSGroup.NameServers, dnsNS)
}
dnsUpdate.NameServerGroups = append(dnsUpdate.NameServerGroups, dnsNSGroup)
}
return dnsUpdate
}
// addNewPeers adds peers that were not know before but arrived from the Management service with the update
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
for _, p := range peersUpdate {

View File

@@ -3,9 +3,11 @@ package internal
import (
"context"
"fmt"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/ssh"
nbstatus "github.com/netbirdio/netbird/client/status"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
"github.com/stretchr/testify/assert"
@@ -440,7 +442,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
expectedSerial uint64
}{
{
name: "Routes Update Should Be Passed To Manager",
name: "Routes Config Should Be Passed To Manager",
networkMap: &mgmtProto.NetworkMap{
Serial: 1,
PeerConfig: nil,
@@ -486,7 +488,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
expectedSerial: 1,
},
{
name: "Empty Routes Update Should Be Passed",
name: "Empty Routes Config Should Be Passed",
networkMap: &mgmtProto.NetworkMap{
Serial: 1,
PeerConfig: nil,
@@ -566,6 +568,183 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
}
}
func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
testCases := []struct {
name string
inputErr error
networkMap *mgmtProto.NetworkMap
expectedZonesLen int
expectedZones []nbdns.CustomZone
expectedNSGroupsLen int
expectedNSGroups []*nbdns.NameServerGroup
expectedSerial uint64
}{
{
name: "DNS Config Should Be Passed To DNS Server",
networkMap: &mgmtProto.NetworkMap{
Serial: 1,
PeerConfig: nil,
RemotePeersIsEmpty: false,
Routes: nil,
DNSConfig: &mgmtProto.DNSConfig{
ServiceEnable: true,
CustomZones: []*mgmtProto.CustomZone{
{
Domain: "netbird.cloud.",
Records: []*mgmtProto.SimpleRecord{
{
Name: "peer-a.netbird.cloud.",
Type: 1,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "100.64.0.1",
},
},
},
},
NameServerGroups: []*mgmtProto.NameServerGroup{
{
Primary: true,
NameServers: []*mgmtProto.NameServer{
{
IP: "8.8.8.8",
NSType: 1,
Port: 53,
},
},
},
},
},
},
expectedZonesLen: 1,
expectedZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud.",
Records: []nbdns.SimpleRecord{
{
Name: "peer-a.netbird.cloud.",
Type: 1,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "100.64.0.1",
},
},
},
},
expectedNSGroupsLen: 1,
expectedNSGroups: []*nbdns.NameServerGroup{
{
Primary: true,
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: 1,
Port: 53,
},
},
},
},
expectedSerial: 1,
},
{
name: "Empty DNS Config Should Be OK",
networkMap: &mgmtProto.NetworkMap{
Serial: 1,
PeerConfig: nil,
RemotePeersIsEmpty: false,
Routes: nil,
DNSConfig: nil,
},
expectedZonesLen: 0,
expectedZones: []nbdns.CustomZone{},
expectedNSGroupsLen: 0,
expectedNSGroups: []*nbdns.NameServerGroup{},
expectedSerial: 1,
},
{
name: "Error Shouldn't Break Engine",
inputErr: fmt.Errorf("mocking error"),
networkMap: &mgmtProto.NetworkMap{
Serial: 1,
PeerConfig: nil,
RemotePeersIsEmpty: false,
Routes: nil,
},
expectedZonesLen: 0,
expectedZones: []nbdns.CustomZone{},
expectedNSGroupsLen: 0,
expectedNSGroups: []*nbdns.NameServerGroup{},
expectedSerial: 1,
},
}
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
// test setup
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
WgIfaceName: wgIfaceName,
WgAddr: wgAddr,
WgPrivateKey: key,
WgPort: 33100,
}, nbstatus.NewRecorder())
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
return nil
},
}
engine.routeManager = mockRouteManager
input := struct {
inputSerial uint64
inputNSGroups []*nbdns.NameServerGroup
inputZones []nbdns.CustomZone
}{}
mockDNSServer := &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error {
input.inputSerial = serial
input.inputZones = update.CustomZones
input.inputNSGroups = update.NameServerGroups
return testCase.inputErr
},
}
engine.dnsServer = mockDNSServer
defer func() {
exitErr := engine.Stop()
if exitErr != nil {
return
}
}()
err = engine.updateNetworkMap(testCase.networkMap)
assert.NoError(t, err, "shouldn't return error")
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
assert.Len(t, input.inputNSGroups, testCase.expectedZonesLen, "zones len should match")
assert.Equal(t, testCase.expectedZones, input.inputZones, "custom zones should match")
assert.Len(t, input.inputNSGroups, testCase.expectedNSGroupsLen, "ns groups len should match")
assert.Equal(t, testCase.expectedNSGroups, input.inputNSGroups, "ns groups should match")
})
}
}
func TestEngine_MultiplePeers(t *testing.T) {
// log.SetLevel(log.DebugLevel)
@@ -761,7 +940,7 @@ func startManagement(port int, dataDir string) (*grpc.Server, error) {
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
}
peersUpdateManager := server.NewPeersUpdateManager()
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "")
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "")
if err != nil {
return nil, err
}