diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 2cf4bff0e..2dd8bdeac 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -42,7 +42,7 @@ type DefaultServer struct { ctx context.Context ctxCancel context.CancelFunc mux sync.Mutex - fakeResolverWG sync.WaitGroup + udpFilterHookID string server *dns.Server dnsMux *dns.ServeMux dnsMuxMap registeredHandlerMap @@ -105,7 +105,10 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd defaultServer.enabled = hasValidDnsServer(initialDnsCfg) } - defaultServer.evalRuntimeAddress() + if wgInterface.IsUserspaceBind() { + defaultServer.evelRuntimeAddressForUserspace() + } + return defaultServer, nil } @@ -118,6 +121,9 @@ func (s *DefaultServer) Initialize() (err error) { return nil } + if !s.wgInterface.IsUserspaceBind() { + s.evalRuntimeAddress() + } s.hostManager, err = newHostManager(s.wgInterface) return } @@ -126,17 +132,8 @@ func (s *DefaultServer) Initialize() (err error) { func (s *DefaultServer) listen() { // nil check required in unit tests if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() { - s.fakeResolverWG.Add(1) - go func() { - s.setListenerStatus(true) - defer s.setListenerStatus(false) - - hookID := s.filterDNSTraffic() - s.fakeResolverWG.Wait() - if err := s.wgInterface.GetFilter().RemovePacketHook(hookID); err != nil { - log.Errorf("unable to remove DNS packet hook: %s", err) - } - }() + s.udpFilterHookID = s.filterDNSTraffic() + s.setListenerStatus(true) return } @@ -153,6 +150,10 @@ func (s *DefaultServer) listen() { }() } +// DnsIP returns the DNS resolver server IP address +// +// When kernel space interface used it return real DNS server listener IP address +// For bind interface, fake DNS resolver address returned (second last IP address from Nebird network) func (s *DefaultServer) DnsIP() string { if !s.enabled { return "" @@ -201,10 +202,6 @@ func (s *DefaultServer) Stop() { } } - if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning { - s.fakeResolverWG.Done() - } - err := s.stopListener() if err != nil { log.Error(err) @@ -212,6 +209,18 @@ func (s *DefaultServer) Stop() { } func (s *DefaultServer) stopListener() error { + if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning { + // udpFilterHookID here empty only in the unit tests + if filter := s.wgInterface.GetFilter(); filter != nil && s.udpFilterHookID != "" { + if err := filter.RemovePacketHook(s.udpFilterHookID); err != nil { + log.Errorf("unable to remove DNS packet hook: %s", err) + } + } + s.udpFilterHookID = "" + s.listenerIsRunning = false + return nil + } + if !s.listenerIsRunning { return nil } @@ -275,12 +284,8 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { // is the service should be disabled, we stop the listener or fake resolver // and proceed with a regular update to clean up the handlers and records if !update.ServiceEnable { - if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() { - s.fakeResolverWG.Done() - } else { - if err := s.stopListener(); err != nil { - log.Error(err) - } + if err := s.stopListener(); err != nil { + log.Error(err) } } else if !s.listenerIsRunning { s.listen() @@ -555,17 +560,17 @@ func (s *DefaultServer) filterDNSTraffic() string { return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook) } +func (s *DefaultServer) evelRuntimeAddressForUserspace() { + s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1) + s.runtimePort = defaultPort + s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort) +} + func (s *DefaultServer) evalRuntimeAddress() { defer func() { s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort) }() - if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() { - s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1) - s.runtimePort = defaultPort - return - } - if s.customAddress != nil { s.runtimeIP = s.customAddress.Addr().String() s.runtimePort = int(s.customAddress.Port()) diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 46ab169fe..2b234fcc0 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -5,15 +5,18 @@ import ( "fmt" "net" "net/netip" + "os" "strings" "testing" "time" + "github.com/golang/mock/gomock" "github.com/miekg/dns" "github.com/netbirdio/netbird/client/internal/stdnet" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/iface" + pfmock "github.com/netbirdio/netbird/iface/mocks" ) var zoneRecords = []nbdns.SimpleRecord{ @@ -241,7 +244,6 @@ func TestUpdateDNSServer(t *testing.T) { dnsServer.updateSerial = testCase.initSerial // pretend we are running dnsServer.listenerIsRunning = true - dnsServer.fakeResolverWG.Add(1) err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate) if err != nil { @@ -276,6 +278,133 @@ func TestUpdateDNSServer(t *testing.T) { } } +func TestDNSFakeResolverHandleUpdates(t *testing.T) { + ov := os.Getenv("NB_WG_KERNEL_DISABLED") + defer os.Setenv("NB_WG_KERNEL_DISABLED", ov) + + os.Setenv("NB_WG_KERNEL_DISABLED", "true") + newNet, err := stdnet.NewNet(nil) + if err != nil { + t.Errorf("create stdnet: %v", err) + return + } + + wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", iface.DefaultMTU, nil, newNet) + if err != nil { + t.Errorf("build interface wireguard: %v", err) + return + } + + err = wgIface.Create() + if err != nil { + t.Errorf("crate and init wireguard interface: %v", err) + return + } + defer func() { + if err = wgIface.Close(); err != nil { + t.Logf("close wireguard interface: %v", err) + } + }() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + _, ipNet, err := net.ParseCIDR("100.66.100.1/32") + if err != nil { + t.Errorf("parse CIDR: %v", err) + return + } + + packetfilter := pfmock.NewMockPacketFilter(ctrl) + packetfilter.EXPECT().SetNetwork(ipNet) + packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes() + packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + packetfilter.EXPECT().RemovePacketHook(gomock.Any()).AnyTimes() + + if err := wgIface.SetFilter(packetfilter); err != nil { + t.Errorf("set packet filter: %v", err) + return + } + + dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", nil) + if err != nil { + t.Errorf("create DNS server: %v", err) + return + } + + err = dnsServer.Initialize() + if err != nil { + t.Errorf("run DNS server: %v", err) + return + } + defer func() { + if err = dnsServer.hostManager.restoreHostDNS(); err != nil { + t.Logf("restore DNS settings on the host: %v", err) + return + } + }() + + dnsServer.dnsMuxMap = registeredHandlerMap{zoneRecords[0].Name: &localResolver{}} + dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}} + dnsServer.updateSerial = 0 + + nameServers := []nbdns.NameServer{ + { + IP: netip.MustParseAddr("8.8.8.8"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + { + IP: netip.MustParseAddr("8.8.4.4"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + } + + update := nbdns.Config{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + { + Domain: "netbird.cloud", + Records: zoneRecords, + }, + }, + NameServerGroups: []*nbdns.NameServerGroup{ + { + Domains: []string{"netbird.io"}, + NameServers: nameServers, + }, + { + NameServers: nameServers, + Primary: true, + }, + }, + } + + // Start the server with regular configuration + if err := dnsServer.UpdateDNSServer(1, update); err != nil { + t.Fatalf("update dns server should not fail, got error: %v", err) + return + } + + update2 := update + update2.ServiceEnable = false + // Disable the server, stop the listener + if err := dnsServer.UpdateDNSServer(2, update2); err != nil { + t.Fatalf("update dns server should not fail, got error: %v", err) + return + } + + update3 := update2 + update3.NameServerGroups = update3.NameServerGroups[:1] + // But service still get updates and we checking that we handle + // internal state in the right way + if err := dnsServer.UpdateDNSServer(3, update3); err != nil { + t.Fatalf("update dns server should not fail, got error: %v", err) + return + } +} + func TestDNSServerStartStop(t *testing.T) { testCases := []struct { name string