mirror of https://github.com/netbirdio/netbird.git (synced 2026-05-04 16:16:40 +00:00)

Compare commits: dependabot ... drop-dns-p (3 commits)

| Author | SHA1 | Date |
|---|---|---|
| | db2a62bf29 | |
| | d0f9d80c3a | |
| | c102592735 | |
@@ -113,7 +113,6 @@ func (c *ConnectClient) RunOniOS(
 	fileDescriptor int32,
 	networkChangeListener listener.NetworkChangeListener,
 	dnsManager dns.IosDnsManager,
-	dnsAddresses []netip.AddrPort,
 	stateFilePath string,
 ) error {
 	// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
@@ -123,7 +122,6 @@ func (c *ConnectClient) RunOniOS(
 		FileDescriptor:        fileDescriptor,
 		NetworkChangeListener: networkChangeListener,
 		DnsManager:            dnsManager,
-		HostDNSAddresses:      dnsAddresses,
 		StateFilePath:         stateFilePath,
 	}
 	return c.run(mobileDependency, nil, "")
@@ -16,6 +16,10 @@ type hostManager interface {
 	restoreHostDNS() error
 	supportCustomPort() bool
 	string() string
+	// getOriginalNameservers returns the OS-side resolvers used as PriorityFallback
+	// upstreams: pre-takeover snapshots on desktop, the OS-pushed list on Android,
+	// hardcoded Quad9 on iOS, nil for noop / mock.
+	getOriginalNameservers() []netip.Addr
 }

 type SystemDNSSettings struct {
@@ -131,3 +135,11 @@ func (n noopHostConfigurator) supportCustomPort() bool {
 func (n noopHostConfigurator) string() string {
 	return "noop"
 }
+
+func (n noopHostConfigurator) getOriginalNameservers() []netip.Addr {
+	return nil
+}
+
+func (m *mockHostConfigurator) getOriginalNameservers() []netip.Addr {
+	return nil
+}
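Note (not part of the change): a minimal sketch, in the same package, of how a caller might consume the new interface method to build fallback upstream addresses. The buildFallbackUpstreams helper and the fixed port are hypothetical, for illustration only.

func buildFallbackUpstreams(hm hostManager) []netip.AddrPort {
	// Hypothetical helper: turn the OS-side resolver snapshot into
	// upstream addresses on the standard DNS port.
	const dnsPort = 53
	addrs := hm.getOriginalNameservers()
	out := make([]netip.AddrPort, 0, len(addrs))
	for _, a := range addrs {
		out = append(out, netip.AddrPortFrom(a, dnsPort))
	}
	return out
}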
@@ -1,28 +1,43 @@
 package dns

 import (
+	"net/netip"
+
 	"github.com/netbirdio/netbird/client/internal/statemanager"
 )

+// androidHostManager is a noop on the OS side (Android's VPN service handles
+// DNS for us) but tracks the OS-reported resolver list pushed via
+// OnUpdatedHostDNSServer so it can serve as the fallback nameserver source.
 type androidHostManager struct {
+	holder *hostsDNSHolder
 }

-func newHostManager() (*androidHostManager, error) {
-	return &androidHostManager{}, nil
+func newHostManager(holder *hostsDNSHolder) (*androidHostManager, error) {
+	return &androidHostManager{holder: holder}, nil
 }

-func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error {
+func (a *androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error {
 	return nil
 }

-func (a androidHostManager) restoreHostDNS() error {
+func (a *androidHostManager) restoreHostDNS() error {
 	return nil
 }

-func (a androidHostManager) supportCustomPort() bool {
+func (a *androidHostManager) supportCustomPort() bool {
 	return false
 }

-func (a androidHostManager) string() string {
+func (a *androidHostManager) string() string {
 	return "none"
 }
+
+func (a *androidHostManager) getOriginalNameservers() []netip.Addr {
+	hosts := a.holder.get()
+	out := make([]netip.Addr, 0, len(hosts))
+	for ap := range hosts {
+		out = append(out, ap.Addr())
+	}
+	return out
+}
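Note (not part of the change): a sketch of the Android flow the new holder field enables, assuming the hostsDNSHolder set/get methods shown elsewhere in this diff. The example resolver address is made up.

// The OS-reported resolver list is pushed into the shared holder
// (via OnUpdatedHostDNSServer in production); the host manager then
// reports those addresses back as fallback candidates.
func exampleAndroidFallback(holder *hostsDNSHolder) []netip.Addr {
	holder.set([]netip.AddrPort{netip.MustParseAddrPort("192.168.1.1:53")})
	hm, err := newHostManager(holder)
	if err != nil {
		return nil
	}
	return hm.getOriginalNameservers() // → [192.168.1.1]
}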
@@ -3,6 +3,7 @@ package dns
 import (
 	"encoding/json"
 	"fmt"
+	"net/netip"

 	log "github.com/sirupsen/logrus"

@@ -14,6 +15,14 @@ type iosHostManager struct {
 	config HostDNSConfig
 }

+func (a iosHostManager) getOriginalNameservers() []netip.Addr {
+	// Quad9 v4+v6: 9.9.9.9, 2620:fe::fe.
+	return []netip.Addr{
+		netip.AddrFrom4([4]byte{9, 9, 9, 9}),
+		netip.AddrFrom16([16]byte{0x26, 0x20, 0x00, 0xfe, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xfe}),
+	}
+}
+
 func newHostManager(dnsManager IosDnsManager) (*iosHostManager, error) {
 	return &iosHostManager{
 		dnsManager: dnsManager,
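Note (not part of the change): the hardcoded Quad9 IPv6 byte literal above encodes the same address as the textual form, which may be easier to audit when reviewing:

var quad9v6 = netip.MustParseAddr("2620:fe::fe") // equal to the AddrFrom16 literal in the hunk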
@@ -7,6 +7,7 @@ import (
 	"io"
 	"net/netip"
 	"os/exec"
+	"slices"
 	"strings"
 	"syscall"
 	"time"
@@ -45,7 +46,9 @@ const (
 	nrptMaxDomainsPerRule = 50

 	interfaceConfigPath           = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
+	interfaceConfigPathV6         = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces`
 	interfaceConfigNameServerKey  = "NameServer"
+	interfaceConfigDhcpNameSrvKey = "DhcpNameServer"
 	interfaceConfigSearchListKey  = "SearchList"

 	// Network interface DNS registration settings
@@ -71,6 +74,7 @@ type registryConfigurator struct {
 	routingAll      bool
 	gpo             bool
 	nrptEntryCount  int
+	origNameservers []netip.Addr
 }

 func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
@@ -94,6 +98,17 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
 		gpo: useGPO,
 	}

+	origNameservers, err := configurator.captureOriginalNameservers()
+	switch {
+	case err != nil:
+		log.Warnf("capture original nameservers from non-WG adapters: %v", err)
+	case len(origNameservers) == 0:
+		log.Warnf("no original nameservers captured from non-WG adapters; DNS fallback will be empty")
+	default:
+		log.Debugf("captured %d original nameservers from non-WG adapters: %v", len(origNameservers), origNameservers)
+	}
+	configurator.origNameservers = origNameservers
+
 	if err := configurator.configureInterface(); err != nil {
 		log.Errorf("failed to configure interface settings: %v", err)
 	}
@@ -101,6 +116,98 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
 	return configurator, nil
 }

+// captureOriginalNameservers reads DNS addresses from every Tcpip(6) interface
+// registry key except the WG adapter. v4 and v6 servers live in separate
+// hives (Tcpip vs Tcpip6) keyed by the same interface GUID.
+func (r *registryConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
+	seen := make(map[netip.Addr]struct{})
+	var out []netip.Addr
+	var merr *multierror.Error
+	for _, root := range []string{interfaceConfigPath, interfaceConfigPathV6} {
+		addrs, err := r.captureFromTcpipRoot(root)
+		if err != nil {
+			merr = multierror.Append(merr, fmt.Errorf("%s: %w", root, err))
+			continue
+		}
+		for _, addr := range addrs {
+			if _, dup := seen[addr]; dup {
+				continue
+			}
+			seen[addr] = struct{}{}
+			out = append(out, addr)
+		}
+	}
+	return out, nberrors.FormatErrorOrNil(merr)
+}
+
+func (r *registryConfigurator) captureFromTcpipRoot(rootPath string) ([]netip.Addr, error) {
+	root, err := registry.OpenKey(registry.LOCAL_MACHINE, rootPath, registry.READ)
+	if err != nil {
+		return nil, fmt.Errorf("open key: %w", err)
+	}
+	defer closer(root)
+
+	guids, err := root.ReadSubKeyNames(-1)
+	if err != nil {
+		return nil, fmt.Errorf("read subkeys: %w", err)
+	}
+
+	var out []netip.Addr
+	for _, guid := range guids {
+		if strings.EqualFold(guid, r.guid) {
+			continue
+		}
+		out = append(out, readInterfaceNameservers(rootPath, guid)...)
+	}
+	return out, nil
+}
+
+func readInterfaceNameservers(rootPath, guid string) []netip.Addr {
+	keyPath := rootPath + "\\" + guid
+	k, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.QUERY_VALUE)
+	if err != nil {
+		return nil
+	}
+	defer closer(k)
+
+	// Static NameServer wins over DhcpNameServer for actual resolution.
+	for _, name := range []string{interfaceConfigNameServerKey, interfaceConfigDhcpNameSrvKey} {
+		raw, _, err := k.GetStringValue(name)
+		if err != nil || raw == "" {
+			continue
+		}
+		if out := parseRegistryNameservers(raw); len(out) > 0 {
+			return out
+		}
+	}
+	return nil
+}
+
+func parseRegistryNameservers(raw string) []netip.Addr {
+	var out []netip.Addr
+	for _, field := range strings.FieldsFunc(raw, func(r rune) bool { return r == ',' || r == ' ' || r == '\t' }) {
+		addr, err := netip.ParseAddr(strings.TrimSpace(field))
+		if err != nil {
+			continue
+		}
+		addr = addr.Unmap()
+		if !addr.IsValid() || addr.IsUnspecified() {
+			continue
+		}
+		// Drop unzoned link-local: not routable without a scope id. If
+		// the user wrote "fe80::1%eth0" ParseAddr preserves the zone.
+		if addr.IsLinkLocalUnicast() && addr.Zone() == "" {
+			continue
+		}
+		out = append(out, addr)
+	}
+	return out
+}
+
+func (r *registryConfigurator) getOriginalNameservers() []netip.Addr {
+	return slices.Clone(r.origNameservers)
+}
+
 func (r *registryConfigurator) supportCustomPort() bool {
 	return false
 }
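Note (not part of the change): a test-style sketch of how parseRegistryNameservers treats a typical registry value, based only on the code in the hunk above. Registry values may mix comma and space separators; routable addresses survive, an unzoned link-local entry is dropped.

func ExampleParseRegistryNameservers() {
	// "10.0.0.1 10.0.0.2,fe80::1" → both routable v4 servers are kept,
	// the zone-less fe80::1 is filtered out.
	fmt.Println(parseRegistryNameservers("10.0.0.1 10.0.0.2,fe80::1"))
	// Output: [10.0.0.1 10.0.0.2]
}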
@@ -25,6 +25,7 @@ func (h *hostsDNSHolder) set(list []netip.AddrPort) {
 	h.mutex.Unlock()
 }

+//nolint:unused
 func (h *hostsDNSHolder) get() map[netip.AddrPort]struct{} {
 	h.mutex.RLock()
 	l := h.unprotectedDNSList
@@ -77,8 +77,6 @@ func (d *Resolver) ID() types.HandlerID {
 	return "local-resolver"
 }

-func (d *Resolver) ProbeAvailability(context.Context) {}
-
 // ServeDNS handles a DNS request
 func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
 	logger := log.WithFields(log.Fields{
@@ -9,6 +9,7 @@ import (

 	dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
 	nbdns "github.com/netbirdio/netbird/dns"
+	"github.com/netbirdio/netbird/route"
 	"github.com/netbirdio/netbird/shared/management/domain"
 )

@@ -70,10 +71,6 @@ func (m *MockServer) SearchDomains() []string {
 	return make([]string, 0)
 }

-// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface
-func (m *MockServer) ProbeAvailability() {
-}
-
 func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
 	if m.UpdateServerConfigFunc != nil {
 		return m.UpdateServerConfigFunc(domains)
@@ -85,8 +82,8 @@ func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
 	return nil
 }

-// SetRouteChecker mock implementation of SetRouteChecker from Server interface
-func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) {
+// SetRouteSources mock implementation of SetRouteSources from Server interface
+func (m *MockServer) SetRouteSources(selected, active func() route.HAMap) {
 	// Mock implementation - no-op
 }

@@ -8,6 +8,7 @@ import (
 	"errors"
 	"fmt"
 	"net/netip"
+	"slices"
 	"strings"
 	"time"

@@ -32,6 +33,15 @@ const (
 	networkManagerDbusDeviceGetAppliedConnectionMethod = networkManagerDbusDeviceInterface + ".GetAppliedConnection"
 	networkManagerDbusDeviceReapplyMethod              = networkManagerDbusDeviceInterface + ".Reapply"
 	networkManagerDbusDeviceDeleteMethod               = networkManagerDbusDeviceInterface + ".Delete"
+	networkManagerDbusDeviceIp4ConfigProperty          = networkManagerDbusDeviceInterface + ".Ip4Config"
+	networkManagerDbusDeviceIp6ConfigProperty          = networkManagerDbusDeviceInterface + ".Ip6Config"
+	networkManagerDbusDeviceIfaceProperty              = networkManagerDbusDeviceInterface + ".Interface"
+	networkManagerDbusGetDevicesMethod                 = networkManagerDest + ".GetDevices"
+	networkManagerDbusIp4ConfigInterface               = "org.freedesktop.NetworkManager.IP4Config"
+	networkManagerDbusIp6ConfigInterface               = "org.freedesktop.NetworkManager.IP6Config"
+	networkManagerDbusIp4ConfigNameserverDataProperty  = networkManagerDbusIp4ConfigInterface + ".NameserverData"
+	networkManagerDbusIp4ConfigNameserversProperty     = networkManagerDbusIp4ConfigInterface + ".Nameservers"
+	networkManagerDbusIp6ConfigNameserversProperty     = networkManagerDbusIp6ConfigInterface + ".Nameservers"
 	networkManagerDbusDefaultBehaviorFlag networkManagerConfigBehavior = 0
 	networkManagerDbusIPv4Key                                          = "ipv4"
 	networkManagerDbusIPv6Key                                          = "ipv6"
@@ -54,6 +64,7 @@ type networkManagerDbusConfigurator struct {
 	dbusLinkObject  dbus.ObjectPath
 	routingAll      bool
 	ifaceName       string
+	origNameservers []netip.Addr
 }

 // the types below are based on dbus specification, each field is mapped to a dbus type
@@ -92,10 +103,200 @@ func newNetworkManagerDbusConfigurator(wgInterface string) (*networkManagerDbusC

 	log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface)

-	return &networkManagerDbusConfigurator{
+	c := &networkManagerDbusConfigurator{
 		dbusLinkObject: dbus.ObjectPath(s),
 		ifaceName:      wgInterface,
-	}, nil
+	}
+
+	origNameservers, err := c.captureOriginalNameservers()
+	switch {
+	case err != nil:
+		log.Warnf("capture original nameservers from NetworkManager: %v", err)
+	case len(origNameservers) == 0:
+		log.Warnf("no original nameservers captured from non-WG NetworkManager devices; DNS fallback will be empty")
+	default:
+		log.Debugf("captured %d original nameservers from non-WG NetworkManager devices: %v", len(origNameservers), origNameservers)
+	}
+	c.origNameservers = origNameservers
+	return c, nil
+}
+
+// captureOriginalNameservers reads DNS servers from every NM device's
+// IP4Config / IP6Config except our WG device.
+func (n *networkManagerDbusConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
+	devices, err := networkManagerListDevices()
+	if err != nil {
+		return nil, fmt.Errorf("list devices: %w", err)
+	}
+
+	seen := make(map[netip.Addr]struct{})
+	var out []netip.Addr
+	for _, dev := range devices {
+		if dev == n.dbusLinkObject {
+			continue
+		}
+		ifaceName := readNetworkManagerDeviceInterface(dev)
+		for _, addr := range readNetworkManagerDeviceDNS(dev) {
+			addr = addr.Unmap()
+			if !addr.IsValid() || addr.IsUnspecified() {
+				continue
+			}
+			// IP6Config.Nameservers is a byte slice without zone info;
+			// reattach the device's interface name so a captured fe80::…
+			// stays routable.
+			if addr.IsLinkLocalUnicast() && ifaceName != "" {
+				addr = addr.WithZone(ifaceName)
+			}
+			if _, dup := seen[addr]; dup {
+				continue
+			}
+			seen[addr] = struct{}{}
+			out = append(out, addr)
+		}
+	}
+	return out, nil
+}
+
+func readNetworkManagerDeviceInterface(devicePath dbus.ObjectPath) string {
+	obj, closeConn, err := getDbusObject(networkManagerDest, devicePath)
+	if err != nil {
+		return ""
+	}
+	defer closeConn()
+	v, err := obj.GetProperty(networkManagerDbusDeviceIfaceProperty)
+	if err != nil {
+		return ""
+	}
+	s, _ := v.Value().(string)
+	return s
+}
+
+func networkManagerListDevices() ([]dbus.ObjectPath, error) {
+	obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
+	if err != nil {
+		return nil, fmt.Errorf("dbus NetworkManager: %w", err)
+	}
+	defer closeConn()
+	var devs []dbus.ObjectPath
+	if err := obj.Call(networkManagerDbusGetDevicesMethod, dbusDefaultFlag).Store(&devs); err != nil {
+		return nil, err
+	}
+	return devs, nil
+}
+
+func readNetworkManagerDeviceDNS(devicePath dbus.ObjectPath) []netip.Addr {
+	obj, closeConn, err := getDbusObject(networkManagerDest, devicePath)
+	if err != nil {
+		return nil
+	}
+	defer closeConn()
+
+	var out []netip.Addr
+	if path := readNetworkManagerConfigPath(obj, networkManagerDbusDeviceIp4ConfigProperty); path != "" {
+		out = append(out, readIPv4ConfigDNS(path)...)
+	}
+	if path := readNetworkManagerConfigPath(obj, networkManagerDbusDeviceIp6ConfigProperty); path != "" {
+		out = append(out, readIPv6ConfigDNS(path)...)
+	}
+	return out
+}
+
+func readNetworkManagerConfigPath(obj dbus.BusObject, property string) dbus.ObjectPath {
+	v, err := obj.GetProperty(property)
+	if err != nil {
+		return ""
+	}
+	path, ok := v.Value().(dbus.ObjectPath)
+	if !ok || path == "/" {
+		return ""
+	}
+	return path
+}
+
+func readIPv4ConfigDNS(path dbus.ObjectPath) []netip.Addr {
+	obj, closeConn, err := getDbusObject(networkManagerDest, path)
+	if err != nil {
+		return nil
+	}
+	defer closeConn()
+
+	// NameserverData (NM 1.13+) carries strings; older NMs only expose the
+	// legacy uint32 Nameservers property.
+	if out := readIPv4NameserverData(obj); len(out) > 0 {
+		return out
+	}
+	return readIPv4LegacyNameservers(obj)
+}
+
+func readIPv4NameserverData(obj dbus.BusObject) []netip.Addr {
+	v, err := obj.GetProperty(networkManagerDbusIp4ConfigNameserverDataProperty)
+	if err != nil {
+		return nil
+	}
+	entries, ok := v.Value().([]map[string]dbus.Variant)
+	if !ok {
+		return nil
+	}
+	var out []netip.Addr
+	for _, entry := range entries {
+		addrVar, ok := entry["address"]
+		if !ok {
+			continue
+		}
+		s, ok := addrVar.Value().(string)
+		if !ok {
+			continue
+		}
+		if a, err := netip.ParseAddr(s); err == nil {
+			out = append(out, a)
+		}
+	}
+	return out
+}
+
+func readIPv4LegacyNameservers(obj dbus.BusObject) []netip.Addr {
+	v, err := obj.GetProperty(networkManagerDbusIp4ConfigNameserversProperty)
+	if err != nil {
+		return nil
+	}
+	raw, ok := v.Value().([]uint32)
+	if !ok {
+		return nil
+	}
+	out := make([]netip.Addr, 0, len(raw))
+	for _, n := range raw {
+		var b [4]byte
+		binary.LittleEndian.PutUint32(b[:], n)
+		out = append(out, netip.AddrFrom4(b))
+	}
+	return out
+}
+
+func readIPv6ConfigDNS(path dbus.ObjectPath) []netip.Addr {
+	obj, closeConn, err := getDbusObject(networkManagerDest, path)
+	if err != nil {
+		return nil
+	}
+	defer closeConn()
+	v, err := obj.GetProperty(networkManagerDbusIp6ConfigNameserversProperty)
+	if err != nil {
+		return nil
+	}
+	raw, ok := v.Value().([][]byte)
+	if !ok {
+		return nil
+	}
+	out := make([]netip.Addr, 0, len(raw))
+	for _, b := range raw {
+		if a, ok := netip.AddrFromSlice(b); ok {
+			out = append(out, a)
+		}
+	}
+	return out
+}
+
+func (n *networkManagerDbusConfigurator) getOriginalNameservers() []netip.Addr {
+	return slices.Clone(n.origNameservers)
 }

 func (n *networkManagerDbusConfigurator) supportCustomPort() bool {
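Note (not part of the change): the legacy IP4Config.Nameservers property carries each IPv4 address packed into a uint32; the binary.LittleEndian decode used above implies the first octet sits in the integer's least-significant byte. A standalone sketch of that decoding, mirroring the loop body in readIPv4LegacyNameservers:

// decodeLegacyNameserver converts one legacy NM nameserver entry;
// e.g. 0x0100000a decodes to 10.0.0.1 under this byte layout.
func decodeLegacyNameserver(n uint32) netip.Addr {
	var b [4]byte
	binary.LittleEndian.PutUint32(b[:], n)
	return netip.AddrFrom4(b)
}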
File diff suppressed because it is too large.
@@ -1,5 +1,5 @@
 package dns

 func (s *DefaultServer) initialize() (manager hostManager, err error) {
-	return newHostManager()
+	return newHostManager(s.hostsDNSHolder)
 }
@@ -6,7 +6,7 @@ import (
 	"net"
 	"net/netip"
 	"os"
-	"strings"
+	"runtime"
 	"testing"
 	"time"

@@ -15,6 +15,7 @@ import (
 	log "github.com/sirupsen/logrus"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/mock"
+	"github.com/stretchr/testify/require"
 	"golang.zx2c4.com/wireguard/tun/netstack"
 	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"

@@ -31,8 +32,10 @@ import (
 	"github.com/netbirdio/netbird/client/internal/peer"
 	"github.com/netbirdio/netbird/client/internal/statemanager"
 	"github.com/netbirdio/netbird/client/internal/stdnet"
+	"github.com/netbirdio/netbird/client/proto"
 	nbdns "github.com/netbirdio/netbird/dns"
 	"github.com/netbirdio/netbird/formatter"
+	"github.com/netbirdio/netbird/route"
 	"github.com/netbirdio/netbird/shared/management/domain"
 )

@@ -101,16 +104,17 @@ func init() {
 	formatter.SetTextFormatter(log.StandardLogger())
 }

-func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
+func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase {
 	var srvs []netip.AddrPort
 	for _, srv := range servers {
 		srvs = append(srvs, srv.AddrPort())
 	}
-	return &upstreamResolverBase{
-		domain:          domain,
-		upstreamServers: srvs,
+	u := &upstreamResolverBase{
+		domain: domain.Domain(d),
 		cancel: func() {},
 	}
+	u.addRace(srvs)
+	return u
 }

 func TestUpdateDNSServer(t *testing.T) {
@@ -653,74 +657,8 @@ func TestDNSServerStartStop(t *testing.T) {
 	}
 }

-func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
-	hostManager := &mockHostConfigurator{}
-	server := DefaultServer{
-		ctx:           context.Background(),
-		service:       NewServiceViaMemory(&mocWGIface{}),
-		localResolver: local.NewResolver(),
-		handlerChain:  NewHandlerChain(),
-		hostManager:   hostManager,
-		currentConfig: HostDNSConfig{
-			Domains: []DomainConfig{
-				{false, "domain0", false},
-				{false, "domain1", false},
-				{false, "domain2", false},
-			},
-		},
-		statusRecorder: peer.NewRecorder("mgm"),
-	}
-
-	var domainsUpdate string
-	hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error {
-		domains := []string{}
-		for _, item := range config.Domains {
-			if item.Disabled {
-				continue
-			}
-			domains = append(domains, item.Domain)
-		}
-		domainsUpdate = strings.Join(domains, ",")
-		return nil
-	}
-
-	deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{
-		Domains: []string{"domain1"},
-		NameServers: []nbdns.NameServer{
-			{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
-		},
-	}, nil, 0)
-
-	deactivate(nil)
-	expected := "domain0,domain2"
-	domains := []string{}
-	for _, item := range server.currentConfig.Domains {
-		if item.Disabled {
-			continue
-		}
-		domains = append(domains, item.Domain)
-	}
-	got := strings.Join(domains, ",")
-	if expected != got {
-		t.Errorf("expected domains list: %q, got %q", expected, got)
-	}
-
-	reactivate()
-	expected = "domain0,domain1,domain2"
-	domains = []string{}
-	for _, item := range server.currentConfig.Domains {
-		if item.Disabled {
-			continue
-		}
-		domains = append(domains, item.Domain)
-	}
-	got = strings.Join(domains, ",")
-	if expected != got {
-		t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate)
-	}
-}
-
 func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
+	skipUnlessAndroid(t)
 	wgIFace, err := createWgInterfaceWithBind(t)
 	if err != nil {
 		t.Fatal("failed to initialize wg interface")
@@ -748,6 +686,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
 }

 func TestDNSPermanent_updateUpstream(t *testing.T) {
+	skipUnlessAndroid(t)
 	wgIFace, err := createWgInterfaceWithBind(t)
 	if err != nil {
 		t.Fatal("failed to initialize wg interface")
@@ -841,6 +780,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
 }

 func TestDNSPermanent_matchOnly(t *testing.T) {
+	skipUnlessAndroid(t)
 	wgIFace, err := createWgInterfaceWithBind(t)
 	if err != nil {
 		t.Fatal("failed to initialize wg interface")
@@ -913,6 +853,18 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
 	}
 }

+// skipUnlessAndroid marks tests that exercise the mobile-permanent DNS path,
+// which only matches a real production setup on android (NewDefaultServerPermanentUpstream
+// + androidHostManager). On non-android the desktop host manager replaces it
+// during Initialize and the assertion stops making sense. Skipped here until we
+// have an android CI runner.
+func skipUnlessAndroid(t *testing.T) {
+	t.Helper()
+	if runtime.GOOS != "android" {
+		t.Skip("requires android runner; mobile-permanent path doesn't match production on this OS")
+	}
+}
+
 func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
 	t.Helper()
 	ov := os.Getenv("NB_WG_KERNEL_DISABLED")
@@ -1065,7 +1017,6 @@ type mockHandler struct {

 func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
 func (m *mockHandler) Stop()                                 {}
-func (m *mockHandler) ProbeAvailability(context.Context)    {}
 func (m *mockHandler) ID() types.HandlerID                   { return types.HandlerID(m.Id) }

 type mockService struct{}
@@ -2085,6 +2036,598 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
 	assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
 }

+// TestBuildUpstreamHandler_MergesGroupsPerDomain verifies that multiple
+// admin-defined nameserver groups targeting the same domain collapse into a
+// single handler with each group preserved as a sequential inner list.
+func TestBuildUpstreamHandler_MergesGroupsPerDomain(t *testing.T) {
+	wgInterface := &mocWGIface{}
+	service := NewServiceViaMemory(wgInterface)
+	server := &DefaultServer{
+		ctx:           context.Background(),
+		wgInterface:   wgInterface,
+		service:       service,
+		localResolver: local.NewResolver(),
+		handlerChain:  NewHandlerChain(),
+		hostManager:   &noopHostConfigurator{},
+		dnsMuxMap:     make(registeredHandlerMap),
+	}
+
+	groups := []*nbdns.NameServerGroup{
+		{
+			NameServers: []nbdns.NameServer{
+				{IP: netip.MustParseAddr("192.0.2.1"), NSType: nbdns.UDPNameServerType, Port: 53},
+			},
+			Domains: []string{"example.com"},
+		},
+		{
+			NameServers: []nbdns.NameServer{
+				{IP: netip.MustParseAddr("192.0.2.2"), NSType: nbdns.UDPNameServerType, Port: 53},
+				{IP: netip.MustParseAddr("192.0.2.3"), NSType: nbdns.UDPNameServerType, Port: 53},
+			},
+			Domains: []string{"example.com"},
+		},
+	}
+
+	muxUpdates, err := server.buildUpstreamHandlerUpdate(groups)
+	require.NoError(t, err)
+	require.Len(t, muxUpdates, 1, "same-domain groups should merge into one handler")
+	assert.Equal(t, "example.com", muxUpdates[0].domain)
+	assert.Equal(t, PriorityUpstream, muxUpdates[0].priority)
+
+	handler := muxUpdates[0].handler.(*upstreamResolver)
+	require.Len(t, handler.upstreamServers, 2, "handler should have two groups")
+	assert.Equal(t, upstreamRace{netip.MustParseAddrPort("192.0.2.1:53")}, handler.upstreamServers[0])
+	assert.Equal(t, upstreamRace{
+		netip.MustParseAddrPort("192.0.2.2:53"),
+		netip.MustParseAddrPort("192.0.2.3:53"),
+	}, handler.upstreamServers[1])
+}
+
+// TestEvaluateNSGroupHealth covers the records-only verdict. The gate
+// (overlay route selected-but-no-active-peer) is intentionally NOT an
+// input to the evaluator anymore: the verdict drives the Enabled flag,
+// which must always reflect what we actually observed. Gate-aware event
+// suppression is tested separately in the projection test.
+//
+// Matrix per upstream: {no record, fresh Ok, fresh Fail, stale Fail,
+// stale Ok, Ok newer than Fail, Fail newer than Ok}.
+// Group verdict: any fresh-working → Healthy; any fresh-broken with no
+// fresh-working → Unhealthy; otherwise Undecided.
+func TestEvaluateNSGroupHealth(t *testing.T) {
+	now := time.Now()
+	a := netip.MustParseAddrPort("192.0.2.1:53")
+	b := netip.MustParseAddrPort("192.0.2.2:53")
+
+	recentOk := UpstreamHealth{LastOk: now.Add(-2 * time.Second)}
+	recentFail := UpstreamHealth{LastFail: now.Add(-1 * time.Second), LastErr: "timeout"}
+	staleOk := UpstreamHealth{LastOk: now.Add(-10 * time.Minute)}
+	staleFail := UpstreamHealth{LastFail: now.Add(-10 * time.Minute), LastErr: "timeout"}
+	okThenFail := UpstreamHealth{
+		LastOk:   now.Add(-10 * time.Second),
+		LastFail: now.Add(-1 * time.Second),
+		LastErr:  "timeout",
+	}
+	failThenOk := UpstreamHealth{
+		LastOk:   now.Add(-1 * time.Second),
+		LastFail: now.Add(-10 * time.Second),
+		LastErr:  "timeout",
+	}
+
+	tests := []struct {
+		name         string
+		health       map[netip.AddrPort]UpstreamHealth
+		servers      []netip.AddrPort
+		wantVerdict  nsGroupVerdict
+		wantErrSubst string
+	}{
+		{
+			name:        "no record, undecided",
+			servers:     []netip.AddrPort{a},
+			wantVerdict: nsVerdictUndecided,
+		},
+		{
+			name:        "fresh success, healthy",
+			health:      map[netip.AddrPort]UpstreamHealth{a: recentOk},
+			servers:     []netip.AddrPort{a},
+			wantVerdict: nsVerdictHealthy,
+		},
+		{
+			name:         "fresh failure, unhealthy",
+			health:       map[netip.AddrPort]UpstreamHealth{a: recentFail},
+			servers:      []netip.AddrPort{a},
+			wantVerdict:  nsVerdictUnhealthy,
+			wantErrSubst: "timeout",
+		},
+		{
+			name:        "only stale success, undecided",
+			health:      map[netip.AddrPort]UpstreamHealth{a: staleOk},
+			servers:     []netip.AddrPort{a},
+			wantVerdict: nsVerdictUndecided,
+		},
+		{
+			name:        "only stale failure, undecided",
+			health:      map[netip.AddrPort]UpstreamHealth{a: staleFail},
+			servers:     []netip.AddrPort{a},
+			wantVerdict: nsVerdictUndecided,
+		},
+		{
+			name:         "both fresh, fail newer, unhealthy",
+			health:       map[netip.AddrPort]UpstreamHealth{a: okThenFail},
+			servers:      []netip.AddrPort{a},
+			wantVerdict:  nsVerdictUnhealthy,
+			wantErrSubst: "timeout",
+		},
+		{
+			name:        "both fresh, ok newer, healthy",
+			health:      map[netip.AddrPort]UpstreamHealth{a: failThenOk},
+			servers:     []netip.AddrPort{a},
+			wantVerdict: nsVerdictHealthy,
+		},
+		{
+			name: "two upstreams, one success wins",
+			health: map[netip.AddrPort]UpstreamHealth{
+				a: recentFail,
+				b: recentOk,
+			},
+			servers:     []netip.AddrPort{a, b},
+			wantVerdict: nsVerdictHealthy,
+		},
+		{
+			name: "two upstreams, one fail one unseen, unhealthy",
+			health: map[netip.AddrPort]UpstreamHealth{
+				a: recentFail,
+			},
+			servers:      []netip.AddrPort{a, b},
+			wantVerdict:  nsVerdictUnhealthy,
+			wantErrSubst: "timeout",
+		},
+		{
+			name: "two upstreams, all recent failures, unhealthy",
+			health: map[netip.AddrPort]UpstreamHealth{
+				a: {LastFail: now.Add(-5 * time.Second), LastErr: "timeout"},
+				b: {LastFail: now.Add(-1 * time.Second), LastErr: "SERVFAIL"},
+			},
+			servers:      []netip.AddrPort{a, b},
+			wantVerdict:  nsVerdictUnhealthy,
+			wantErrSubst: "SERVFAIL",
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			verdict, err := evaluateNSGroupHealth(tc.health, tc.servers, now)
+			assert.Equal(t, tc.wantVerdict, verdict, "verdict mismatch")
+			if tc.wantErrSubst != "" {
+				require.Error(t, err)
+				assert.Contains(t, err.Error(), tc.wantErrSubst)
+			} else {
+				assert.NoError(t, err)
+			}
+		})
+	}
+}
+
+// healthStubHandler is a minimal dnsMuxMap entry that exposes a fixed
+// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates
+// without spinning up real handlers.
+type healthStubHandler struct {
+	health map[netip.AddrPort]UpstreamHealth
+}
+
+func (h *healthStubHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
+func (h *healthStubHandler) Stop()                                 {}
+func (h *healthStubHandler) ID() types.HandlerID                   { return "health-stub" }
+func (h *healthStubHandler) UpstreamHealth() map[netip.AddrPort]UpstreamHealth {
+	return h.health
+}
+
+// TestProjection_SteadyStateIsSilent guards against duplicate events:
+// while a group stays Unhealthy tick after tick, only the first
+// Unhealthy transition may emit. Same for staying Healthy.
+func TestProjection_SteadyStateIsSilent(t *testing.T) {
+	fx := newProjTestFixture(t)
+
+	fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
+	fx.tick()
+	fx.expectEvent("unreachable", "first fail emits warning")
+
+	fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
+	fx.tick()
+	fx.tick()
+	fx.expectNoEvent("staying unhealthy must not re-emit")
+
+	fx.setHealth(UpstreamHealth{LastOk: time.Now()})
+	fx.tick()
+	fx.expectEvent("recovered", "recovery on transition")
+
+	fx.tick()
+	fx.tick()
+	fx.expectNoEvent("staying healthy must not re-emit")
+}
+
+// projTestFixture is the common setup for the projection tests: a
+// single-upstream group whose route classification the test can flip by
+// assigning to selected/active. Callers drive failures/successes by
+// mutating stub.health and calling refreshHealth.
+type projTestFixture struct {
+	t        *testing.T
+	recorder *peer.Status
+	events   <-chan *proto.SystemEvent
+	server   *DefaultServer
+	stub     *healthStubHandler
+	group    *nbdns.NameServerGroup
+	srv      netip.AddrPort
+	selected route.HAMap
+	active   route.HAMap
+}
+
+func newProjTestFixture(t *testing.T) *projTestFixture {
+	t.Helper()
+	recorder := peer.NewRecorder("mgm")
+	sub := recorder.SubscribeToEvents()
+	t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
+
+	srv := netip.MustParseAddrPort("100.64.0.1:53")
+	fx := &projTestFixture{
+		t:        t,
+		recorder: recorder,
+		events:   sub.Events(),
+		stub:     &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{}},
+		srv:      srv,
+		group: &nbdns.NameServerGroup{
+			Domains:     []string{"example.com"},
+			NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
+		},
+	}
+	fx.server = &DefaultServer{
+		ctx:              context.Background(),
+		wgInterface:      &mocWGIface{},
+		statusRecorder:   recorder,
+		dnsMuxMap:        make(registeredHandlerMap),
+		selectedRoutes:   func() route.HAMap { return fx.selected },
+		activeRoutes:     func() route.HAMap { return fx.active },
+		warningDelayBase: defaultWarningDelayBase,
+	}
+	fx.server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}
+
+	fx.server.mux.Lock()
+	fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group})
+	fx.server.mux.Unlock()
+	return fx
+}
+
+func (f *projTestFixture) setHealth(h UpstreamHealth) {
+	f.stub.health = map[netip.AddrPort]UpstreamHealth{f.srv: h}
+}
+
+func (f *projTestFixture) tick() []peer.NSGroupState {
+	f.server.refreshHealth()
+	return f.recorder.GetDNSStates()
+}
+
+func (f *projTestFixture) expectNoEvent(why string) {
+	f.t.Helper()
+	select {
+	case evt := <-f.events:
+		f.t.Fatalf("unexpected event (%s): %+v", why, evt)
+	case <-time.After(100 * time.Millisecond):
+	}
+}
+
+func (f *projTestFixture) expectEvent(substr, why string) *proto.SystemEvent {
+	f.t.Helper()
+	select {
+	case evt := <-f.events:
+		assert.Contains(f.t, evt.Message, substr, why)
+		return evt
+	case <-time.After(time.Second):
+		f.t.Fatalf("expected event (%s) with %q", why, substr)
+		return nil
+	}
+}
+
+var overlayNetForTest = netip.MustParsePrefix("100.64.0.0/16")
+var overlayMapForTest = route.HAMap{"overlay": {{Network: overlayNetForTest}}}
+
+// TestProjection_PublicFailEmitsImmediately covers rule 1: an upstream
+// that is not inside any selected route (public DNS) fires the warning
+// on the first Unhealthy tick, no grace period.
+func TestProjection_PublicFailEmitsImmediately(t *testing.T) {
+	fx := newProjTestFixture(t)
+
+	fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
+	states := fx.tick()
+	require.Len(t, states, 1)
+	assert.False(t, states[0].Enabled)
+	fx.expectEvent("unreachable", "public DNS failure")
+}
+
+// TestProjection_OverlayConnectedFailEmitsImmediately covers rule 2:
+// the upstream is inside a selected route AND the route has a Connected
+// peer. Tunnel is up, failure is real, emit immediately.
+func TestProjection_OverlayConnectedFailEmitsImmediately(t *testing.T) {
+	fx := newProjTestFixture(t)
+	fx.selected = overlayMapForTest
+	fx.active = overlayMapForTest
+
+	fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
+	states := fx.tick()
+	require.Len(t, states, 1)
+	assert.False(t, states[0].Enabled)
+	fx.expectEvent("unreachable", "overlay + connected failure")
+}
+
+// TestProjection_OverlayNotConnectedDelaysWarning covers rule 3: the
+// upstream is routed but no peer is Connected (Connecting/Idle/missing).
+// First tick: Unhealthy display, no warning. After the grace window
+// elapses with no recovery, the warning fires.
+func TestProjection_OverlayNotConnectedDelaysWarning(t *testing.T) {
+	grace := 50 * time.Millisecond
+	fx := newProjTestFixture(t)
+	fx.server.warningDelayBase = grace
+	fx.selected = overlayMapForTest
+	// active stays nil: routed but not connected.
+
+	fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
+	states := fx.tick()
+	require.Len(t, states, 1)
+	assert.False(t, states[0].Enabled, "display must reflect failure even during grace window")
+	fx.expectNoEvent("first fail tick within grace window")
+
+	time.Sleep(grace + 10*time.Millisecond)
+	fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
+	fx.tick()
+	fx.expectEvent("unreachable", "warning after grace window")
+}
+
+// TestProjection_OverlayAddrNoRouteDelaysWarning covers an upstream
+// whose address is inside the WireGuard overlay range but is not
+// covered by any selected route (peer-to-peer DNS without an explicit
+// route). Until a peer reports Connected for that address, startup
+// failures must be held just like the routed case.
+func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
+	recorder := peer.NewRecorder("mgm")
+	sub := recorder.SubscribeToEvents()
+	t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
+
+	overlayPeer := netip.MustParseAddrPort("100.66.100.5:53")
+	server := &DefaultServer{
+		ctx:              context.Background(),
+		wgInterface:      &mocWGIface{},
+		statusRecorder:   recorder,
+		dnsMuxMap:        make(registeredHandlerMap),
+		selectedRoutes:   func() route.HAMap { return nil },
+		activeRoutes:     func() route.HAMap { return nil },
+		warningDelayBase: 50 * time.Millisecond,
+	}
+	group := &nbdns.NameServerGroup{
+		Domains:     []string{"example.com"},
+		NameServers: []nbdns.NameServer{{IP: overlayPeer.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlayPeer.Port())}},
+	}
+	stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{
+		overlayPeer: {LastFail: time.Now(), LastErr: "timeout"},
+	}}
+	server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
+
+	server.mux.Lock()
+	server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
+	server.mux.Unlock()
+	server.refreshHealth()
+
+	select {
+	case evt := <-sub.Events():
+		t.Fatalf("unexpected event during grace window: %+v", evt)
+	case <-time.After(100 * time.Millisecond):
+	}
+
+	time.Sleep(60 * time.Millisecond)
+	stub.health = map[netip.AddrPort]UpstreamHealth{overlayPeer: {LastFail: time.Now(), LastErr: "timeout"}}
+	server.refreshHealth()
+
+	select {
+	case evt := <-sub.Events():
+		assert.Contains(t, evt.Message, "unreachable")
+	case <-time.After(time.Second):
+		t.Fatal("expected warning after grace window")
+	}
+}
+
+// TestProjection_StopClearsHealthState verifies that Stop wipes the
+// per-group projection state so a subsequent Start doesn't inherit
+// sticky flags (notably everHealthy) that would bypass the grace
+// window during the next peer handshake.
+func TestProjection_StopClearsHealthState(t *testing.T) {
+	wgIface := &mocWGIface{}
+	server := &DefaultServer{
+		ctx:               context.Background(),
+		wgInterface:       wgIface,
+		service:           NewServiceViaMemory(wgIface),
+		hostManager:       &noopHostConfigurator{},
+		extraDomains:      map[domain.Domain]int{},
+		dnsMuxMap:         make(registeredHandlerMap),
+		statusRecorder:    peer.NewRecorder("mgm"),
+		selectedRoutes:    func() route.HAMap { return nil },
+		activeRoutes:      func() route.HAMap { return nil },
+		warningDelayBase:  defaultWarningDelayBase,
+		currentConfigHash: ^uint64(0),
+	}
+	server.ctx, server.ctxCancel = context.WithCancel(context.Background())
+
+	srv := netip.MustParseAddrPort("8.8.8.8:53")
+	group := &nbdns.NameServerGroup{
+		Domains:     []string{"example.com"},
+		NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
+	}
+	stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}}
+	server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
+
+	server.mux.Lock()
+	server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
+	server.mux.Unlock()
+	server.refreshHealth()
+
+	server.healthProjectMu.Lock()
+	p, ok := server.nsGroupProj[generateGroupKey(group)]
+	server.healthProjectMu.Unlock()
+	require.True(t, ok, "projection state should exist after tick")
+	require.True(t, p.everHealthy, "tick with success must set everHealthy")
+
+	server.Stop()
+
+	server.healthProjectMu.Lock()
+	cleared := server.nsGroupProj == nil
+	server.healthProjectMu.Unlock()
+	assert.True(t, cleared, "Stop must clear nsGroupProj")
+}
+
+// TestProjection_OverlayRecoversDuringGrace covers the happy path of
+// rule 3: startup failures while the peer is handshaking, then the peer
+// comes up and a query succeeds before the grace window elapses. No
+// warning should ever have fired, and no recovery either.
+func TestProjection_OverlayRecoversDuringGrace(t *testing.T) {
+	fx := newProjTestFixture(t)
+	fx.server.warningDelayBase = 200 * time.Millisecond
+	fx.selected = overlayMapForTest
+
+	fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
+	fx.tick()
+	fx.expectNoEvent("fail within grace, warning suppressed")
+
+	fx.active = overlayMapForTest
+	fx.setHealth(UpstreamHealth{LastOk: time.Now()})
+	states := fx.tick()
+	require.Len(t, states, 1)
+	assert.True(t, states[0].Enabled)
+	fx.expectNoEvent("recovery without prior warning must not emit")
+}
+
+// TestProjection_RecoveryOnlyAfterWarning enforces the invariant the
+// whole design leans on: recovery events only appear when a warning
+// event was actually emitted for the current streak. A Healthy verdict
+// without a prior warning is silent, so the user never sees "recovered"
+// out of thin air.
+func TestProjection_RecoveryOnlyAfterWarning(t *testing.T) {
+	fx := newProjTestFixture(t)
+
+	fx.setHealth(UpstreamHealth{LastOk: time.Now()})
+	states := fx.tick()
+	require.Len(t, states, 1)
+	assert.True(t, states[0].Enabled)
+	fx.expectNoEvent("first healthy tick should not recover anything")
+
+	fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
+	fx.tick()
+	fx.expectEvent("unreachable", "public fail emits immediately")
+
+	fx.setHealth(UpstreamHealth{LastOk: time.Now()})
+	fx.tick()
+	fx.expectEvent("recovered", "recovery follows real warning")
+
+	fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
+	fx.tick()
+	fx.expectEvent("unreachable", "second cycle warning")
+
+	fx.setHealth(UpstreamHealth{LastOk: time.Now()})
+	fx.tick()
+	fx.expectEvent("recovered", "second cycle recovery")
+}
+
+// TestProjection_EverHealthyOverridesDelay covers rule 4: once a group
+// has ever been Healthy, subsequent failures skip the grace window even
+// if classification says "routed + not connected". The system has
+// proved it can work, so any new failure is real.
+func TestProjection_EverHealthyOverridesDelay(t *testing.T) {
+	fx := newProjTestFixture(t)
+	// Large base so any emission must come from the everHealthy bypass, not elapsed time.
+	fx.server.warningDelayBase = time.Hour
+	fx.selected = overlayMapForTest
+	fx.active = overlayMapForTest
+
+	// Establish "ever healthy".
+	fx.setHealth(UpstreamHealth{LastOk: time.Now()})
+	fx.tick()
+	fx.expectNoEvent("first healthy tick")
+
+	// Peer drops. Query fails. Routed + not connected → normally grace,
+	// but everHealthy flag bypasses it.
+	fx.active = nil
+	fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
+	fx.tick()
+	fx.expectEvent("unreachable", "failure after ever-healthy must be immediate")
+}
+
+// TestProjection_ReconnectBlipEmitsPair covers the explicit tradeoff
+// from the design discussion: once a group has been healthy, a brief
+// reconnect that produces a failing tick will fire warning + recovery.
+// This is by design: user-visible blips are accurate signal, not noise.
+func TestProjection_ReconnectBlipEmitsPair(t *testing.T) {
+	fx := newProjTestFixture(t)
+	fx.selected = overlayMapForTest
+	fx.active = overlayMapForTest
+
+	fx.setHealth(UpstreamHealth{LastOk: time.Now()})
+	fx.tick()
+
+	fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
+	fx.tick()
+	fx.expectEvent("unreachable", "blip warning")
+
+	fx.setHealth(UpstreamHealth{LastOk: time.Now()})
+	fx.tick()
+	fx.expectEvent("recovered", "blip recovery")
+}
+
+// TestProjection_MixedGroupEmitsImmediately covers the multi-upstream
+// rule: a group with at least one public upstream is in the "immediate"
+// category regardless of the other upstreams' routing, because the
+// public one has no peer-startup excuse. Prevents public-DNS failures
+// from being hidden behind a routed sibling.
+func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
+	recorder := peer.NewRecorder("mgm")
+	sub := recorder.SubscribeToEvents()
+	t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
+	events := sub.Events()
+
+	public := netip.MustParseAddrPort("8.8.8.8:53")
+	overlay := netip.MustParseAddrPort("100.64.0.1:53")
+	overlayMap := route.HAMap{"overlay": {{Network: netip.MustParsePrefix("100.64.0.0/16")}}}
+
+	server := &DefaultServer{
+		ctx:              context.Background(),
+		statusRecorder:   recorder,
+		dnsMuxMap:        make(registeredHandlerMap),
+		selectedRoutes:   func() route.HAMap { return overlayMap },
+		activeRoutes:     func() route.HAMap { return nil },
+		warningDelayBase: time.Hour,
+	}
+	group := &nbdns.NameServerGroup{
+		Domains: []string{"example.com"},
+		NameServers: []nbdns.NameServer{
+			{IP: public.Addr(), NSType: nbdns.UDPNameServerType, Port: int(public.Port())},
+			{IP: overlay.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlay.Port())},
+		},
+	}
+	stub := &healthStubHandler{
+		health: map[netip.AddrPort]UpstreamHealth{
+			public:  {LastFail: time.Now(), LastErr: "servfail"},
+			overlay: {LastFail: time.Now(), LastErr: "timeout"},
+		},
+	}
+	server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
+
+	server.mux.Lock()
+	server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
+	server.mux.Unlock()
+	server.refreshHealth()
+
+	select {
+	case evt := <-events:
+		assert.Contains(t, evt.Message, "unreachable")
+	case <-time.After(time.Second):
+		t.Fatal("expected immediate warning because group contains a public upstream")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDNSLoopPrevention(t *testing.T) {
    wgInterface := &mocWGIface{}
    service := NewServiceViaMemory(wgInterface)
@@ -2183,17 +2726,18 @@ func TestDNSLoopPrevention(t *testing.T) {

        if tt.expectedHandlers > 0 {
            handler := muxUpdates[0].handler.(*upstreamResolver)
-            assert.Len(t, handler.upstreamServers, len(tt.expectedServers))
+            flat := handler.flatUpstreams()
+            assert.Len(t, flat, len(tt.expectedServers))

            if tt.shouldFilterOwnIP {
-                for _, upstream := range handler.upstreamServers {
+                for _, upstream := range flat {
                    assert.NotEqual(t, dnsServerIP, upstream.Addr())
                }
            }

            for _, expected := range tt.expectedServers {
                found := false
-                for _, upstream := range handler.upstreamServers {
+                for _, upstream := range flat {
                    if upstream.Addr() == expected {
                        found = true
                        break
@@ -8,6 +8,7 @@ import (
    "fmt"
    "net"
    "net/netip"
+    "slices"
    "time"

    "github.com/godbus/dbus/v5"
@@ -42,8 +43,15 @@ const (
type systemdDbusConfigurator struct {
    dbusLinkObject dbus.ObjectPath
    ifaceName      string
+    wgIndex        int
+    origNameservers []netip.Addr
}
+
+const (
+    systemdDbusLinkDNSProperty          = systemdDbusLinkInterface + ".DNS"
+    systemdDbusLinkDefaultRouteProperty = systemdDbusLinkInterface + ".DefaultRoute"
+)
+
// the types below are based on dbus specification, each field is mapped to a dbus type
// see https://dbus.freedesktop.org/doc/dbus-specification.html#basic-types for more details on dbus types
// see https://www.freedesktop.org/software/systemd/man/org.freedesktop.resolve1.html on resolve1 input types
@@ -79,10 +87,145 @@ func newSystemdDbusConfigurator(wgInterface string) (*systemdDbusConfigurator, e

    log.Debugf("got dbus Link interface: %s from net interface %s and index %d", s, iface.Name, iface.Index)

-    return &systemdDbusConfigurator{
+    c := &systemdDbusConfigurator{
        dbusLinkObject: dbus.ObjectPath(s),
        ifaceName:      wgInterface,
-    }, nil
+        wgIndex:        iface.Index,
+    }
+
+    origNameservers, err := c.captureOriginalNameservers()
+    switch {
+    case err != nil:
+        log.Warnf("capture original nameservers from systemd-resolved: %v", err)
+    case len(origNameservers) == 0:
+        log.Warnf("no original nameservers captured from systemd-resolved default-route links; DNS fallback will be empty")
+    default:
+        log.Debugf("captured %d original nameservers from systemd-resolved default-route links: %v", len(origNameservers), origNameservers)
+    }
+    c.origNameservers = origNameservers
+    return c, nil
+}
+
+// captureOriginalNameservers reads per-link DNS from systemd-resolved for
+// every default-route link except our own WG link. Non-default-route links
+// (VPNs, docker bridges) are skipped because their upstreams wouldn't
+// actually serve host queries.
+func (s *systemdDbusConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
+    ifaces, err := net.Interfaces()
+    if err != nil {
+        return nil, fmt.Errorf("list interfaces: %w", err)
+    }
+
+    seen := make(map[netip.Addr]struct{})
+    var out []netip.Addr
+    for _, iface := range ifaces {
+        if !s.isCandidateLink(iface) {
+            continue
+        }
+        linkPath, err := getSystemdLinkPath(iface.Index)
+        if err != nil || !isSystemdLinkDefaultRoute(linkPath) {
+            continue
+        }
+        for _, addr := range readSystemdLinkDNS(linkPath) {
+            addr = normalizeSystemdAddr(addr, iface.Name)
+            if !addr.IsValid() {
+                continue
+            }
+            if _, dup := seen[addr]; dup {
+                continue
+            }
+            seen[addr] = struct{}{}
+            out = append(out, addr)
+        }
+    }
+    return out, nil
+}
+
+func (s *systemdDbusConfigurator) isCandidateLink(iface net.Interface) bool {
+    if iface.Index == s.wgIndex {
+        return false
+    }
+    if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 {
+        return false
+    }
+    return true
+}
+
+// normalizeSystemdAddr unmaps v4-mapped-v6, drops unspecified, and reattaches
+// the link's iface name as zone for link-local v6 (Link.DNS strips it).
+// Returns the zero Addr to signal "skip this entry".
+func normalizeSystemdAddr(addr netip.Addr, ifaceName string) netip.Addr {
+    addr = addr.Unmap()
+    if !addr.IsValid() || addr.IsUnspecified() {
+        return netip.Addr{}
+    }
+    if addr.IsLinkLocalUnicast() {
+        return addr.WithZone(ifaceName)
+    }
+    return addr
+}
+
+func getSystemdLinkPath(ifIndex int) (dbus.ObjectPath, error) {
+    obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
+    if err != nil {
+        return "", fmt.Errorf("dbus resolve1: %w", err)
+    }
+    defer closeConn()
+    var p string
+    if err := obj.Call(systemdDbusGetLinkMethod, dbusDefaultFlag, int32(ifIndex)).Store(&p); err != nil {
+        return "", err
+    }
+    return dbus.ObjectPath(p), nil
+}
+
+func isSystemdLinkDefaultRoute(linkPath dbus.ObjectPath) bool {
+    obj, closeConn, err := getDbusObject(systemdResolvedDest, linkPath)
+    if err != nil {
+        return false
+    }
+    defer closeConn()
+    v, err := obj.GetProperty(systemdDbusLinkDefaultRouteProperty)
+    if err != nil {
+        return false
+    }
+    b, ok := v.Value().(bool)
+    return ok && b
+}
+
+func readSystemdLinkDNS(linkPath dbus.ObjectPath) []netip.Addr {
+    obj, closeConn, err := getDbusObject(systemdResolvedDest, linkPath)
+    if err != nil {
+        return nil
+    }
+    defer closeConn()
+    v, err := obj.GetProperty(systemdDbusLinkDNSProperty)
+    if err != nil {
+        return nil
+    }
+    entries, ok := v.Value().([][]any)
+    if !ok {
+        return nil
+    }
+    var out []netip.Addr
+    for _, entry := range entries {
+        if len(entry) < 2 {
+            continue
+        }
+        raw, ok := entry[1].([]byte)
+        if !ok {
+            continue
+        }
+        addr, ok := netip.AddrFromSlice(raw)
+        if !ok {
+            continue
+        }
+        out = append(out, addr)
+    }
+    return out
+}
+
+func (s *systemdDbusConfigurator) getOriginalNameservers() []netip.Addr {
+    return slices.Clone(s.origNameservers)
}

func (s *systemdDbusConfigurator) supportCustomPort() bool {
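A minimal sketch (not part of this diff) of how the normalizeSystemdAddr helper above treats the three interesting inputs; the "eth0" zone and the addresses are illustrative only:

    // Same dns package as the hunk above; inputs are made up.
    func exampleNormalizeSystemdAddr() []netip.Addr {
        return []netip.Addr{
            normalizeSystemdAddr(netip.MustParseAddr("::ffff:192.0.2.1"), "eth0"), // 192.0.2.1: v4-mapped address unmapped
            normalizeSystemdAddr(netip.MustParseAddr("0.0.0.0"), "eth0"),          // zero Addr: unspecified, caller skips it
            normalizeSystemdAddr(netip.MustParseAddr("fe80::1"), "eth0"),          // fe80::1%eth0: zone reattached
        }
    }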
@@ -1,3 +1,32 @@
+// Package dns implements the client-side DNS stack: listener/service on the
+// peer's tunnel address, handler chain that routes questions by domain and
+// priority, and upstream resolvers that forward what remains to configured
+// nameservers.
+//
+// # Upstream resolution and the race model
+//
+// When two or more nameserver groups target the same domain, DefaultServer
+// merges them into one upstream handler whose state is:
+//
+// upstreamResolverBase
+// └── upstreamServers []upstreamRace // one entry per source NS group
+//     └── []netip.AddrPort // primary, fallback, ...
+//
+// Each source nameserver group contributes one upstreamRace. Within a race
+// upstreams are tried in order: the next is used only on failure (timeout,
+// SERVFAIL, REFUSED, no response). NXDOMAIN is a valid answer and stops
+// the walk. When more than one race exists, ServeDNS fans out one
+// goroutine per race and returns the first valid answer, cancelling the
+// rest. A handler with a single race skips the fan-out.
+//
+// # Health projection
+//
+// Query outcomes are recorded per-upstream in UpstreamHealth. The server
+// periodically merges these snapshots across handlers and projects them
+// into peer.NSGroupState. There is no active probing: a group is marked
+// unhealthy only when every seen upstream has a recent failure and none
+// has a recent success. Healthy→unhealthy fires a single
+// SystemEvent_WARNING; steady-state refreshes do not duplicate it.
package dns

import (
@@ -11,11 +40,8 @@ import (
    "slices"
    "strings"
    "sync"
-    "sync/atomic"
    "time"

-    "github.com/cenkalti/backoff/v4"
-    "github.com/hashicorp/go-multierror"
    "github.com/miekg/dns"
    log "github.com/sirupsen/logrus"
    "golang.zx2c4.com/wireguard/tun/netstack"
@@ -24,7 +50,8 @@ import (
    "github.com/netbirdio/netbird/client/internal/dns/resutil"
    "github.com/netbirdio/netbird/client/internal/dns/types"
    "github.com/netbirdio/netbird/client/internal/peer"
-    "github.com/netbirdio/netbird/client/proto"
+    "github.com/netbirdio/netbird/route"
+    "github.com/netbirdio/netbird/shared/management/domain"
)
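The fan-out the new package comment describes is the usual first-answer-wins pattern. A self-contained sketch with made-up names (not the code in this diff), assuming only the standard library:

    package example

    import "context"

    // firstAnswer runs one worker per race; within a race, upstreams are tried
    // in order and only failures advance to the next one. The first non-empty
    // answer wins and the shared context cancels the remaining workers.
    func firstAnswer(ctx context.Context, races [][]string, query func(context.Context, string) (string, error)) (string, bool) {
        raceCtx, cancel := context.WithCancel(ctx)
        defer cancel() // losers are cancelled once a winner is returned

        results := make(chan string, len(races)) // buffered so losers never block after return
        for _, race := range races {
            go func(race []string) {
                for _, upstream := range race { // serial fallback within one race
                    if ans, err := query(raceCtx, upstream); err == nil {
                        results <- ans
                        return
                    }
                }
                results <- "" // every upstream in this race failed
            }(race)
        }
        for range races {
            if ans := <-results; ans != "" {
                return ans, true
            }
        }
        return "", false
    }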

var currentMTU uint16 = iface.DefaultMTU
@@ -39,15 +66,17 @@ const (
    // Set longer than UpstreamTimeout to ensure context timeout takes precedence
    ClientTimeout = 5 * time.Second

-    reactivatePeriod = 30 * time.Second
-    probeTimeout = 2 * time.Second

    // ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP
    // payload from the tunnel MTU.
    ipUDPHeaderSize = 60 + 8
-)

-const testRecord = "com."
+    // raceMaxTotalTimeout caps the combined time spent walking all upstreams
+    // within one race, so a slow primary can't eat the whole race budget.
+    raceMaxTotalTimeout = 5 * time.Second
+    // raceMinPerUpstreamTimeout is the floor applied when dividing
+    // raceMaxTotalTimeout across upstreams within a race.
+    raceMinPerUpstreamTimeout = 2 * time.Second
+)

const (
    protoUDP = "udp"
@@ -56,6 +85,68 @@ const (

type dnsProtocolKey struct{}
+
+type upstreamProtocolKey struct{}
+
+// upstreamProtocolResult holds the protocol used for the upstream exchange.
+// Stored as a pointer in context so the exchange function can set it.
+type upstreamProtocolResult struct {
+    protocol string
+}
+
+type upstreamClient interface {
+    exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
+}
+
+type UpstreamResolver interface {
+    serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error)
+    upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
+}
+
+// upstreamRace is an ordered list of upstreams derived from one configured
+// nameserver group. Order matters: the first upstream is tried first, the
+// second only on failure, and so on. Multiple upstreamRace values coexist
+// inside one resolver when overlapping nameserver groups target the same
+// domain; those races run in parallel and the first valid answer wins.
+type upstreamRace []netip.AddrPort
+
+// UpstreamHealth is the last query-path outcome for a single upstream,
+// consumed by nameserver-group status projection.
+type UpstreamHealth struct {
+    LastOk   time.Time
+    LastFail time.Time
+    LastErr  string
+}
+
+type upstreamResolverBase struct {
+    ctx             context.Context
+    cancel          context.CancelFunc
+    upstreamClient  upstreamClient
+    upstreamServers []upstreamRace
+    domain          domain.Domain
+    upstreamTimeout time.Duration
+
+    healthMu sync.RWMutex
+    health   map[netip.AddrPort]*UpstreamHealth
+
+    statusRecorder *peer.Status
+    // selectedRoutes returns the current set of client routes the admin
+    // has enabled. Called lazily from the query hot path when an upstream
+    // might need a tunnel-bound client (iOS) and from health projection.
+    selectedRoutes func() route.HAMap
+}
+
+type upstreamFailure struct {
+    upstream netip.AddrPort
+    reason   string
+}
+
+type raceResult struct {
+    msg      *dns.Msg
+    upstream netip.AddrPort
+    protocol string
+    failures []upstreamFailure
+}
+
// contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context.
func contextWithDNSProtocol(ctx context.Context, network string) context.Context {
    return context.WithValue(ctx, dnsProtocolKey{}, network)
@@ -72,16 +163,8 @@ func dnsProtocolFromContext(ctx context.Context) string {
    return ""
}

-type upstreamProtocolKey struct{}
-
-// upstreamProtocolResult holds the protocol used for the upstream exchange.
-// Stored as a pointer in context so the exchange function can set it.
-type upstreamProtocolResult struct {
-    protocol string
-}
-
-// contextWithupstreamProtocolResult stores a mutable result holder in the context.
-func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
+// contextWithUpstreamProtocolResult stores a mutable result holder in the context.
+func contextWithUpstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
    r := &upstreamProtocolResult{}
    return context.WithValue(ctx, upstreamProtocolKey{}, r), r
}
@@ -96,68 +179,38 @@ func setUpstreamProtocol(ctx context.Context, protocol string) {
    }
}

-type upstreamClient interface {
-    exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
-}
-
-type UpstreamResolver interface {
-    serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error)
-    upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
-}
-
-type upstreamResolverBase struct {
-    ctx              context.Context
-    cancel           context.CancelFunc
-    upstreamClient   upstreamClient
-    upstreamServers  []netip.AddrPort
-    domain           string
-    disabled         bool
-    successCount     atomic.Int32
-    mutex            sync.Mutex
-    reactivatePeriod time.Duration
-    upstreamTimeout  time.Duration
-    wg               sync.WaitGroup
-
-    deactivate     func(error)
-    reactivate     func()
-    statusRecorder *peer.Status
-    routeMatch     func(netip.Addr) bool
-}
-
-type upstreamFailure struct {
-    upstream netip.AddrPort
-    reason   string
-}
-
-func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
+func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d domain.Domain) *upstreamResolverBase {
    ctx, cancel := context.WithCancel(ctx)

    return &upstreamResolverBase{
        ctx:             ctx,
        cancel:          cancel,
-        domain:          domain,
+        domain:          d,
        upstreamTimeout: UpstreamTimeout,
-        reactivatePeriod: reactivatePeriod,
        statusRecorder:  statusRecorder,
    }
}

// String returns a string representation of the upstream resolver
func (u *upstreamResolverBase) String() string {
-    return fmt.Sprintf("Upstream %s", u.upstreamServers)
+    return fmt.Sprintf("Upstream %s", u.flatUpstreams())
}

-// ID returns the unique handler ID
+// ID returns the unique handler ID. Race groupings and within-race
+// ordering are both part of the identity: [[A,B]] and [[A],[B]] query
+// the same servers but with different semantics (serial fallback vs
+// parallel race), so their handlers must not collide.
func (u *upstreamResolverBase) ID() types.HandlerID {
-    servers := slices.Clone(u.upstreamServers)
-    slices.SortFunc(servers, func(a, b netip.AddrPort) int { return a.Compare(b) })
-
    hash := sha256.New()
-    hash.Write([]byte(u.domain + ":"))
+    hash.Write([]byte(u.domain.PunycodeString() + ":"))
-    for _, s := range servers {
+    for _, race := range u.upstreamServers {
+        hash.Write([]byte("["))
+        for _, s := range race {
            hash.Write([]byte(s.String()))
            hash.Write([]byte("|"))
        }
+        hash.Write([]byte("]"))
+    }
    return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
}

@@ -166,13 +219,31 @@ func (u *upstreamResolverBase) MatchSubdomains() bool {
}

func (u *upstreamResolverBase) Stop() {
-    log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
+    log.Debugf("stopping serving DNS for upstreams %s", u.flatUpstreams())
    u.cancel()
+}
-    u.mutex.Lock()
-    u.wg.Wait()
-    u.mutex.Unlock()
+
+// flatUpstreams is for logging and ID hashing only, not for dispatch.
+func (u *upstreamResolverBase) flatUpstreams() []netip.AddrPort {
+    var out []netip.AddrPort
+    for _, g := range u.upstreamServers {
+        out = append(out, g...)
+    }
+    return out
+}
+
+// setSelectedRoutes swaps the accessor used to classify overlay-routed
+// upstreams. Called when route sources are wired after the handler was
+// built (permanent / iOS constructors).
+func (u *upstreamResolverBase) setSelectedRoutes(selected func() route.HAMap) {
+    u.selectedRoutes = selected
+}
+
+func (u *upstreamResolverBase) addRace(servers []netip.AddrPort) {
+    if len(servers) == 0 {
+        return
+    }
+    u.upstreamServers = append(u.upstreamServers, slices.Clone(servers))
}

// ServeDNS handles a DNS request
@@ -214,59 +285,172 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
}

func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
-    timeout := u.upstreamTimeout
-    if len(u.upstreamServers) > 1 {
-        maxTotal := 5 * time.Second
-        minPerUpstream := 2 * time.Second
-        scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers))
-        if scaledTimeout > minPerUpstream {
-            timeout = scaledTimeout
-        } else {
-            timeout = minPerUpstream
+    groups := u.upstreamServers
+    switch len(groups) {
+    case 0:
+        return false, nil
+    case 1:
+        return u.tryOnlyRace(ctx, w, r, groups[0], logger)
+    default:
+        return u.raceAll(ctx, w, r, groups, logger)
    }
+}
+
+func (u *upstreamResolverBase) tryOnlyRace(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, group upstreamRace, logger *log.Entry) (bool, []upstreamFailure) {
+    res := u.tryRace(ctx, r, group)
+    if res.msg == nil {
+        return false, res.failures
+    }
+    u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger)
+    return true, res.failures
+}
+
+// raceAll runs one worker per group in parallel, taking the first valid
+// answer and cancelling the rest.
+func (u *upstreamResolverBase) raceAll(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, groups []upstreamRace, logger *log.Entry) (bool, []upstreamFailure) {
+    raceCtx, cancel := context.WithCancel(ctx)
+    defer cancel()
+
+    // Buffer sized to len(groups) so workers never block on send, even
+    // after the coordinator has returned.
+    results := make(chan raceResult, len(groups))
+    for _, g := range groups {
+        // tryRace clones the request per attempt, so workers never share
+        // a *dns.Msg and concurrent EDNS0 mutations can't race.
+        go func(g upstreamRace) {
+            results <- u.tryRace(raceCtx, r, g)
+        }(g)
    }

    var failures []upstreamFailure
-    for _, upstream := range u.upstreamServers {
-        if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil {
-            failures = append(failures, *failure)
-        } else {
+    for range groups {
+        select {
+        case res := <-results:
+            failures = append(failures, res.failures...)
+            if res.msg != nil {
+                u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger)
                return true, failures
            }
+        case <-ctx.Done():
+            return false, failures
+        }
    }
    return false, failures
}

-// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
-func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
-    var rm *dns.Msg
-    var t time.Duration
-    var err error
+func (u *upstreamResolverBase) tryRace(ctx context.Context, r *dns.Msg, group upstreamRace) raceResult {
+    timeout := u.upstreamTimeout
+    if len(group) > 1 {
+        // Cap the whole walk at raceMaxTotalTimeout: per-upstream timeouts
+        // still honor raceMinPerUpstreamTimeout as a floor for correctness
+        // on slow links, but the outer context ensures the combined walk
+        // cannot exceed the cap regardless of group size.
+        timeout = max(raceMaxTotalTimeout/time.Duration(len(group)), raceMinPerUpstreamTimeout)
+        var cancel context.CancelFunc
+        ctx, cancel = context.WithTimeout(ctx, raceMaxTotalTimeout)
+        defer cancel()
+    }

-    var startTime time.Time
-    var upstreamProto *upstreamProtocolResult
-    func() {
+    var failures []upstreamFailure
+    for _, upstream := range group {
+        if ctx.Err() != nil {
+            return raceResult{failures: failures}
+        }
+        // Clone the request per attempt: the exchange path mutates EDNS0
+        // options in-place, so reusing the same *dns.Msg across sequential
+        // upstreams would carry those mutations (e.g. a reduced UDP size)
+        // into the next attempt.
+        msg, proto, failure := u.queryUpstream(ctx, r.Copy(), upstream, timeout)
+        if failure != nil {
+            failures = append(failures, *failure)
+            continue
+        }
+        return raceResult{msg: msg, upstream: upstream, protocol: proto, failures: failures}
+    }
+    return raceResult{failures: failures}
+}
+
+func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration) (*dns.Msg, string, *upstreamFailure) {
    ctx, cancel := context.WithTimeout(parentCtx, timeout)
    defer cancel()
-    ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
-    startTime = time.Now()
-    rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
-    }()
+    ctx, upstreamProto := contextWithUpstreamProtocolResult(ctx)
+    startTime := time.Now()
+    rm, _, err := u.upstreamClient.exchange(ctx, upstream.String(), r)

    if err != nil {
-        return u.handleUpstreamError(err, upstream, startTime)
+        // A parent cancellation (e.g., another race won and the coordinator
+        // cancelled the losers) is not an upstream failure. Check both the
+        // error chain and the parent context: a transport may surface the
+        // cancellation as a read/deadline error rather than context.Canceled.
+        if errors.Is(err, context.Canceled) || errors.Is(parentCtx.Err(), context.Canceled) {
+            return nil, "", &upstreamFailure{upstream: upstream, reason: "canceled"}
+        }
+        failure := u.handleUpstreamError(err, upstream, startTime)
+        u.markUpstreamFail(upstream, failure.reason)
+        return nil, "", failure
    }

    if rm == nil || !rm.Response {
-        return &upstreamFailure{upstream: upstream, reason: "no response"}
+        u.markUpstreamFail(upstream, "no response")
+        return nil, "", &upstreamFailure{upstream: upstream, reason: "no response"}
    }

    if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
-        return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
+        reason := dns.RcodeToString[rm.Rcode]
+        u.markUpstreamFail(upstream, reason)
+        return nil, "", &upstreamFailure{upstream: upstream, reason: reason}
    }

-    u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
-    return nil
+    u.markUpstreamOk(upstream)
+
+    proto := ""
+    if upstreamProto != nil {
+        proto = upstreamProto.protocol
+    }
+    return rm, proto, nil
+}
+
+// healthEntry returns the mutable health record for addr, lazily creating
+// the map and the entry. Caller must hold u.healthMu.
+func (u *upstreamResolverBase) healthEntry(addr netip.AddrPort) *UpstreamHealth {
+    if u.health == nil {
+        u.health = make(map[netip.AddrPort]*UpstreamHealth)
+    }
+    h := u.health[addr]
+    if h == nil {
+        h = &UpstreamHealth{}
+        u.health[addr] = h
+    }
+    return h
+}
+
+func (u *upstreamResolverBase) markUpstreamOk(addr netip.AddrPort) {
+    u.healthMu.Lock()
+    defer u.healthMu.Unlock()
+    h := u.healthEntry(addr)
+    h.LastOk = time.Now()
+    h.LastFail = time.Time{}
+    h.LastErr = ""
+}
+
+func (u *upstreamResolverBase) markUpstreamFail(addr netip.AddrPort, reason string) {
+    u.healthMu.Lock()
+    defer u.healthMu.Unlock()
+    h := u.healthEntry(addr)
+    h.LastFail = time.Now()
+    h.LastErr = reason
+}
+
+// UpstreamHealth returns a snapshot of per-upstream query outcomes.
+func (u *upstreamResolverBase) UpstreamHealth() map[netip.AddrPort]UpstreamHealth {
+    u.healthMu.RLock()
+    defer u.healthMu.RUnlock()
+    out := make(map[netip.AddrPort]UpstreamHealth, len(u.health))
+    for k, v := range u.health {
+        out[k] = *v
+    }
+    return out
}

func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
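With the constants introduced earlier, the per-upstream timeout inside tryRace works out as follows (illustrative arithmetic only, not part of the patch):

    // len(group) == 1: timeout stays at u.upstreamTimeout, no cap is applied.
    // len(group) == 2: max(5s/2, 2s) = 2.5s per upstream, 5s total cap.
    // len(group) == 4: max(5s/4, 2s) = 2s per upstream; the 5s outer context
    //                  ends the walk before all four attempts can complete.
    perUpstream := max(raceMaxTotalTimeout/time.Duration(4), raceMinPerUpstreamTimeout) // 2s
    _ = perUpstream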
@@ -282,12 +466,23 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
    return &upstreamFailure{upstream: upstream, reason: reason}
}

-func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool {
-    u.successCount.Add(1)
+func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string {
+    if u.statusRecorder == nil {
+        return ""
+    }
+
+    peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder)
+    if peerInfo == nil {
+        return ""
+    }
+
+    return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
+}
+
+func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, proto string, logger *log.Entry) {
    resutil.SetMeta(w, "upstream", upstream.String())
-    if upstreamProto != nil && upstreamProto.protocol != "" {
+    if proto != "" {
-        resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol)
+        resutil.SetMeta(w, "upstream_protocol", proto)
    }

    // Clear Zero bit from external responses to prevent upstream servers from
@@ -296,14 +491,11 @@ func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dn

    if err := w.WriteMsg(rm); err != nil {
        logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
-        return true
    }

-    return true
}

func (u *upstreamResolverBase) logUpstreamFailures(domain string, failures []upstreamFailure, succeeded bool, logger *log.Entry) {
-    totalUpstreams := len(u.upstreamServers)
+    totalUpstreams := len(u.flatUpstreams())
    failedCount := len(failures)
    failureSummary := formatFailures(failures)

@@ -330,119 +522,6 @@ func formatFailures(failures []upstreamFailure) string {
    return strings.Join(parts, ", ")
}

-// ProbeAvailability tests all upstream servers simultaneously and
-// disables the resolver if none work
-func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {
-    u.mutex.Lock()
-    defer u.mutex.Unlock()
-
-    // avoid probe if upstreams could resolve at least one query
-    if u.successCount.Load() > 0 {
-        return
-    }
-
-    var success bool
-    var mu sync.Mutex
-    var wg sync.WaitGroup
-
-    var errs *multierror.Error
-    for _, upstream := range u.upstreamServers {
-        wg.Add(1)
-        go func(upstream netip.AddrPort) {
-            defer wg.Done()
-            err := u.testNameserver(u.ctx, ctx, upstream, 500*time.Millisecond)
-            if err != nil {
-                mu.Lock()
-                errs = multierror.Append(errs, err)
-                mu.Unlock()
-                log.Warnf("probing upstream nameserver %s: %s", upstream, err)
-                return
-            }
-
-            mu.Lock()
-            success = true
-            mu.Unlock()
-        }(upstream)
-    }
-
-    wg.Wait()
-
-    select {
-    case <-ctx.Done():
-        return
-    case <-u.ctx.Done():
-        return
-    default:
-    }
-
-    // didn't find a working upstream server, let's disable and try later
-    if !success {
-        u.disable(errs.ErrorOrNil())
-
-        if u.statusRecorder == nil {
-            return
-        }
-
-        u.statusRecorder.PublishEvent(
-            proto.SystemEvent_WARNING,
-            proto.SystemEvent_DNS,
-            "All upstream servers failed (probe failed)",
-            "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
-            map[string]string{"upstreams": u.upstreamServersString()},
-        )
-    }
-}
-
-// waitUntilResponse retries, in an exponential interval, querying the upstream servers until it gets a positive response
-func (u *upstreamResolverBase) waitUntilResponse() {
-    exponentialBackOff := &backoff.ExponentialBackOff{
-        InitialInterval:     500 * time.Millisecond,
-        RandomizationFactor: 0.5,
-        Multiplier:          1.1,
-        MaxInterval:         u.reactivatePeriod,
-        MaxElapsedTime:      0,
-        Stop:                backoff.Stop,
-        Clock:               backoff.SystemClock,
-    }
-
-    operation := func() error {
-        select {
-        case <-u.ctx.Done():
-            return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServersString()))
-        default:
-        }
-
-        for _, upstream := range u.upstreamServers {
-            if err := u.testNameserver(u.ctx, nil, upstream, probeTimeout); err != nil {
-                log.Tracef("upstream check for %s: %s", upstream, err)
-            } else {
-                // at least one upstream server is available, stop probing
-                return nil
-            }
-        }
-
-        log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServersString(), exponentialBackOff.NextBackOff())
-        return fmt.Errorf("upstream check call error")
-    }
-
-    err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx))
-    if err != nil {
-        if errors.Is(err, context.Canceled) {
-            log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString())
-        } else {
-            log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err)
-        }
-        return
-    }
-
-    log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
-    u.successCount.Add(1)
-    u.reactivate()
-    u.mutex.Lock()
-    u.disabled = false
-    u.mutex.Unlock()
-}
-
// isTimeout returns true if the given error is a network timeout error.
//
// Copied from k8s.io/apimachinery/pkg/util/net.IsTimeout
@@ -454,45 +533,6 @@ func isTimeout(err error) bool {
    return false
}

-func (u *upstreamResolverBase) disable(err error) {
-    if u.disabled {
-        return
-    }
-
-    log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
-    u.successCount.Store(0)
-    u.deactivate(err)
-    u.disabled = true
-    u.wg.Add(1)
-    go func() {
-        defer u.wg.Done()
-        u.waitUntilResponse()
-    }()
-}
-
-func (u *upstreamResolverBase) upstreamServersString() string {
-    var servers []string
-    for _, server := range u.upstreamServers {
-        servers = append(servers, server.String())
-    }
-    return strings.Join(servers, ", ")
-}
-
-func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalCtx context.Context, server netip.AddrPort, timeout time.Duration) error {
-    mergedCtx, cancel := context.WithTimeout(baseCtx, timeout)
-    defer cancel()
-
-    if externalCtx != nil {
-        stop2 := context.AfterFunc(externalCtx, cancel)
-        defer stop2()
-    }
-
-    r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
-
-    _, _, err := u.upstreamClient.exchange(mergedCtx, server.String(), r)
-    return err
-}
-
// clientUDPMaxSize returns the maximum UDP response size the client accepts.
func clientUDPMaxSize(r *dns.Msg) int {
    if opt := r.IsEdns0(); opt != nil {
@@ -504,13 +544,10 @@ func clientUDPMaxSize(r *dns.Msg) int {
// ExchangeWithFallback exchanges a DNS message with the upstream server.
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
// If the inbound request came over TCP (via context), it skips the UDP attempt.
-// If the passed context is nil, this will use Exchange instead of ExchangeContext.
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
    // If the request came in over TCP, go straight to TCP upstream.
    if dnsProtocolFromContext(ctx) == protoTCP {
-        tcpClient := *client
-        tcpClient.Net = protoTCP
-        rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream)
+        rm, t, err := toTCPClient(client).ExchangeContext(ctx, r, upstream)
        if err != nil {
            return nil, t, fmt.Errorf("with tcp: %w", err)
        }
@@ -530,18 +567,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
        opt.SetUDPSize(maxUDPPayload)
    }

-    var (
-        rm  *dns.Msg
-        t   time.Duration
-        err error
-    )
-
-    if ctx == nil {
-        rm, t, err = client.Exchange(r, upstream)
-    } else {
-        rm, t, err = client.ExchangeContext(ctx, r, upstream)
-    }
-
+    rm, t, err := client.ExchangeContext(ctx, r, upstream)
    if err != nil {
        return nil, t, fmt.Errorf("with udp: %w", err)
    }
@@ -555,15 +581,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
    // data than the client's buffer, we could truncate locally and skip
    // the TCP retry.

-    tcpClient := *client
-    tcpClient.Net = protoTCP
-
-    if ctx == nil {
-        rm, t, err = tcpClient.Exchange(r, upstream)
-    } else {
-        rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream)
-    }
-
+    rm, t, err = toTCPClient(client).ExchangeContext(ctx, r, upstream)
    if err != nil {
        return nil, t, fmt.Errorf("with tcp: %w", err)
    }
@@ -577,6 +595,25 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
    return rm, t, nil
}

+// toTCPClient returns a copy of c configured for TCP. If c's Dialer has a
+// *net.UDPAddr bound as LocalAddr (iOS does this to keep the source IP on
+// the tunnel interface), it is converted to the equivalent *net.TCPAddr
+// so net.Dialer doesn't reject the TCP dial with "mismatched local
+// address type".
+func toTCPClient(c *dns.Client) *dns.Client {
+    tcp := *c
+    tcp.Net = protoTCP
+    if tcp.Dialer == nil {
+        return &tcp
+    }
+    d := *tcp.Dialer
+    if ua, ok := d.LocalAddr.(*net.UDPAddr); ok {
+        d.LocalAddr = &net.TCPAddr{IP: ua.IP, Port: ua.Port, Zone: ua.Zone}
+    }
+    tcp.Dialer = &d
+    return &tcp
+}
+
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
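A sketch (not part of the patch) of what the Dialer conversion in toTCPClient amounts to for an iOS-style client; the tunnel address is made up:

    udpBound := &dns.Client{
        Dialer: &net.Dialer{LocalAddr: &net.UDPAddr{IP: net.ParseIP("100.92.0.5")}},
    }
    tcpCopy := toTCPClient(udpBound)
    // tcpCopy.Net == "tcp" and tcpCopy.Dialer.LocalAddr is now a *net.TCPAddr with
    // the same IP, so the truncation fallback dial keeps the tunnel source address.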
@@ -718,15 +755,36 @@ func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State {
    return bestMatch
}

-func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string {
-    if u.statusRecorder == nil {
-        return ""
-    }
-    peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder)
-    if peerInfo == nil {
-        return ""
-    }
-    return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
-}
+// haMapRouteCount returns the total number of routes across all HA
+// groups in the map. route.HAMap is keyed by HAUniqueID with slices of
+// routes per key, so len(hm) is the number of HA groups, not routes.
+func haMapRouteCount(hm route.HAMap) int {
+    total := 0
+    for _, routes := range hm {
+        total += len(routes)
+    }
+    return total
+}
+
+// haMapContains checks whether ip is covered by any concrete prefix in
+// the HA map. haveDynamic is reported separately: dynamic (domain-based)
+// routes carry a placeholder Network that can't be prefix-checked, so we
+// can't know at this point whether ip is reached through one. Callers
+// decide how to interpret the unknown: health projection treats it as
+// "possibly routed" to avoid emitting false-positive warnings during
+// startup, while iOS dial selection requires a concrete match before
+// binding to the tunnel.
+func haMapContains(hm route.HAMap, ip netip.Addr) (matched, haveDynamic bool) {
+    for _, routes := range hm {
+        for _, r := range routes {
+            if r.IsDynamic() {
+                haveDynamic = true
+                continue
+            }
+            if r.Network.Contains(ip) {
+                return true, haveDynamic
+            }
+        }
+    }
+    return false, haveDynamic
+}
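The classification that the haMapContains comment describes, using the same overlay prefix the new tests use (illustrative only, not part of the patch):

    hm := route.HAMap{"overlay": {{Network: netip.MustParsePrefix("100.64.0.0/16")}}}
    matched, haveDynamic := haMapContains(hm, netip.MustParseAddr("100.64.0.1")) // true, false
    matched, haveDynamic = haMapContains(hm, netip.MustParseAddr("8.8.8.8"))     // false, false
    _, _ = matched, haveDynamic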
@@ -11,6 +11,7 @@ import (

    "github.com/netbirdio/netbird/client/internal/peer"
    nbnet "github.com/netbirdio/netbird/client/net"
+    "github.com/netbirdio/netbird/shared/management/domain"
)

type upstreamResolver struct {
@@ -26,9 +27,9 @@ func newUpstreamResolver(
    _ WGIface,
    statusRecorder *peer.Status,
    hostsDNSHolder *hostsDNSHolder,
-    domain string,
+    d domain.Domain,
) (*upstreamResolver, error) {
-    upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
+    upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
    c := &upstreamResolver{
        upstreamResolverBase: upstreamResolverBase,
        hostsDNSHolder:       hostsDNSHolder,
@@ -12,6 +12,7 @@ import (
    "golang.zx2c4.com/wireguard/tun/netstack"

    "github.com/netbirdio/netbird/client/internal/peer"
+    "github.com/netbirdio/netbird/shared/management/domain"
)

type upstreamResolver struct {
@@ -24,9 +25,9 @@ func newUpstreamResolver(
    wgIface WGIface,
    statusRecorder *peer.Status,
    _ *hostsDNSHolder,
-    domain string,
+    d domain.Domain,
) (*upstreamResolver, error) {
-    upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
+    upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
    nonIOS := &upstreamResolver{
        upstreamResolverBase: upstreamResolverBase,
        nsNet:                wgIface.GetNet(),
@@ -15,6 +15,7 @@ import (
    "golang.org/x/sys/unix"

    "github.com/netbirdio/netbird/client/internal/peer"
+    "github.com/netbirdio/netbird/shared/management/domain"
)

type upstreamResolverIOS struct {
@@ -29,9 +30,9 @@ func newUpstreamResolver(
    wgIface WGIface,
    statusRecorder *peer.Status,
    _ *hostsDNSHolder,
-    domain string,
+    d domain.Domain,
) (*upstreamResolverIOS, error) {
-    upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
+    upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)

    ios := &upstreamResolverIOS{
        upstreamResolverBase: upstreamResolverBase,
@@ -65,8 +66,14 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
    } else {
        upstreamIP = upstreamIP.Unmap()
    }
-    needsPrivate := u.lNet.Contains(upstreamIP) ||
-        (u.routeMatch != nil && u.routeMatch(upstreamIP))
+    var routed bool
+    if u.selectedRoutes != nil {
+        // Only a concrete prefix match binds to the tunnel: dialing
+        // through a private client for an upstream we can't prove is
+        // routed would break public resolvers.
+        routed, _ = haMapContains(u.selectedRoutes(), upstreamIP)
+    }
+    needsPrivate := u.lNet.Contains(upstreamIP) || routed
    if needsPrivate {
        log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
        client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
@@ -75,8 +82,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
        }
    }

-    // Cannot use client.ExchangeContext because it overwrites our Dialer
-    return ExchangeWithFallback(nil, client, r, upstream)
+    return ExchangeWithFallback(ctx, client, r, upstream)
}

// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
@@ -6,6 +6,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -73,7 +74,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
             servers = append(servers, netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()))
         }
     }
-    resolver.upstreamServers = servers
+    resolver.addRace(servers)
     resolver.upstreamTimeout = testCase.timeout
     if testCase.cancelCTX {
         cancel()
@@ -132,20 +133,10 @@ func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) {
     return "", nil
 }

-type mockUpstreamResolver struct {
-    r   *dns.Msg
-    rtt time.Duration
-    err error
-}
-
-// exchange mock implementation of exchange from upstreamResolver
-func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
-    return c.r, c.rtt, c.err
-}
-
 type mockUpstreamResponse struct {
     msg *dns.Msg
     err error
+    delay time.Duration
 }

 type mockUpstreamResolverPerServer struct {
@@ -153,63 +144,19 @@ type mockUpstreamResolverPerServer struct {
     rtt       time.Duration
 }

-func (c mockUpstreamResolverPerServer) exchange(_ context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
-    if r, ok := c.responses[upstream]; ok {
-        return r.msg, c.rtt, r.err
-    }
-    return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
-}
-
-func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
-    mockClient := &mockUpstreamResolver{
-        err: dns.ErrTime,
-        r:   new(dns.Msg),
-        rtt: time.Millisecond,
-    }
-
-    resolver := &upstreamResolverBase{
-        ctx:              context.TODO(),
-        upstreamClient:   mockClient,
-        upstreamTimeout:  UpstreamTimeout,
-        reactivatePeriod: time.Microsecond * 100,
-    }
-    addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection
-    resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())}
-
-    failed := false
-    resolver.deactivate = func(error) {
-        failed = true
-        // After deactivation, make the mock client work again
-        mockClient.err = nil
-    }
-
-    reactivated := false
-    resolver.reactivate = func() {
-        reactivated = true
-    }
-
-    resolver.ProbeAvailability(context.TODO())
-
-    if !failed {
-        t.Errorf("expected that resolving was deactivated")
-        return
-    }
-
-    if !resolver.disabled {
-        t.Errorf("resolver should be Disabled")
-        return
-    }
-
-    time.Sleep(time.Millisecond * 200)
-
-    if !reactivated {
-        t.Errorf("expected that resolving was reactivated")
-        return
-    }
-
-    if resolver.disabled {
-        t.Errorf("should be enabled")
-    }
-}
+func (c mockUpstreamResolverPerServer) exchange(ctx context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
+    r, ok := c.responses[upstream]
+    if !ok {
+        return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
+    }
+
+    if r.delay > 0 {
+        select {
+        case <-time.After(r.delay):
+        case <-ctx.Done():
+            return nil, c.rtt, ctx.Err()
+        }
+    }
+
+    return r.msg, c.rtt, r.err
+}

 func TestUpstreamResolver_Failover(t *testing.T) {
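The delay field introduced above lets a mocked upstream stay silent for a while yet remain cancellable through the context. A standalone sketch of the same select-on-delay-or-context pattern, with made-up names (slowAnswer is not part of the test suite):

package main

import (
	"context"
	"fmt"
	"time"
)

// slowAnswer simulates an upstream that only replies after delay, but still
// honors cancellation, mirroring the select added to the per-server mock.
func slowAnswer(ctx context.Context, delay time.Duration, answer string) (string, error) {
	if delay > 0 {
		select {
		case <-time.After(delay):
		case <-ctx.Done():
			return "", ctx.Err()
		}
	}
	return answer, nil
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	// The 500ms "upstream" loses to the 100ms context, so the caller is
	// unblocked by cancellation instead of waiting out the full delay.
	if _, err := slowAnswer(ctx, 500*time.Millisecond, "192.0.2.100"); err != nil {
		fmt.Println("cancelled:", err)
	}
}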
@@ -339,9 +286,9 @@ func TestUpstreamResolver_Failover(t *testing.T) {
     resolver := &upstreamResolverBase{
         ctx:             ctx,
         upstreamClient:  trackingClient,
-        upstreamServers: []netip.AddrPort{upstream1, upstream2},
         upstreamTimeout: UpstreamTimeout,
     }
+    resolver.addRace([]netip.AddrPort{upstream1, upstream2})

     var responseMSG *dns.Msg
     responseWriter := &test.MockResponseWriter{
@@ -421,9 +368,9 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
     resolver := &upstreamResolverBase{
         ctx:             ctx,
         upstreamClient:  mockClient,
-        upstreamServers: []netip.AddrPort{upstream},
         upstreamTimeout: UpstreamTimeout,
     }
+    resolver.addRace([]netip.AddrPort{upstream})

     var responseMSG *dns.Msg
     responseWriter := &test.MockResponseWriter{
@@ -440,6 +387,136 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
     assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode, "single upstream SERVFAIL should return SERVFAIL")
 }

+// TestUpstreamResolver_RaceAcrossGroups covers two nameserver groups
+// configured for the same domain, with one broken group. The merge+race
+// path should answer as fast as the working group and not pay the timeout
+// of the broken one on every query.
+func TestUpstreamResolver_RaceAcrossGroups(t *testing.T) {
+    broken := netip.MustParseAddrPort("192.0.2.1:53")
+    working := netip.MustParseAddrPort("192.0.2.2:53")
+    successAnswer := "192.0.2.100"
+    timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
+
+    mockClient := &mockUpstreamResolverPerServer{
+        responses: map[string]mockUpstreamResponse{
+            // Force the broken upstream to only unblock via timeout /
+            // cancellation so the assertion below can't pass if races
+            // were run serially.
+            broken.String():  {err: timeoutErr, delay: 500 * time.Millisecond},
+            working.String(): {msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
+        },
+        rtt: time.Millisecond,
+    }
+
+    ctx, cancel := context.WithCancel(context.Background())
+    defer cancel()
+
+    resolver := &upstreamResolverBase{
+        ctx:             ctx,
+        upstreamClient:  mockClient,
+        upstreamTimeout: 250 * time.Millisecond,
+    }
+    resolver.addRace([]netip.AddrPort{broken})
+    resolver.addRace([]netip.AddrPort{working})
+
+    var responseMSG *dns.Msg
+    responseWriter := &test.MockResponseWriter{
+        WriteMsgFunc: func(m *dns.Msg) error {
+            responseMSG = m
+            return nil
+        },
+    }
+
+    inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
+    start := time.Now()
+    resolver.ServeDNS(responseWriter, inputMSG)
+    elapsed := time.Since(start)
+
+    require.NotNil(t, responseMSG, "should write a response")
+    assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode)
+    require.NotEmpty(t, responseMSG.Answer)
+    assert.Contains(t, responseMSG.Answer[0].String(), successAnswer)
+    // Working group answers in a single RTT; the broken group's
+    // timeout (100ms) must not block the response.
+    assert.Less(t, elapsed, 100*time.Millisecond, "race must not wait for broken group's timeout")
+}
+
+// TestUpstreamResolver_AllGroupsFail checks that when every group fails the
+// resolver returns SERVFAIL rather than leaking a partial response.
+func TestUpstreamResolver_AllGroupsFail(t *testing.T) {
+    a := netip.MustParseAddrPort("192.0.2.1:53")
+    b := netip.MustParseAddrPort("192.0.2.2:53")
+
+    mockClient := &mockUpstreamResolverPerServer{
+        responses: map[string]mockUpstreamResponse{
+            a.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
+            b.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
+        },
+        rtt: time.Millisecond,
+    }
+
+    ctx, cancel := context.WithCancel(context.Background())
+    defer cancel()
+
+    resolver := &upstreamResolverBase{
+        ctx:             ctx,
+        upstreamClient:  mockClient,
+        upstreamTimeout: UpstreamTimeout,
+    }
+    resolver.addRace([]netip.AddrPort{a})
+    resolver.addRace([]netip.AddrPort{b})
+
+    var responseMSG *dns.Msg
+    responseWriter := &test.MockResponseWriter{
+        WriteMsgFunc: func(m *dns.Msg) error {
+            responseMSG = m
+            return nil
+        },
+    }
+
+    resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
+    require.NotNil(t, responseMSG)
+    assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode)
+}
+
+// TestUpstreamResolver_HealthTracking verifies that query-path results are
+// recorded into per-upstream health, which is what projects back to
+// NSGroupState for status reporting.
+func TestUpstreamResolver_HealthTracking(t *testing.T) {
+    ok := netip.MustParseAddrPort("192.0.2.10:53")
+    bad := netip.MustParseAddrPort("192.0.2.11:53")
+
+    mockClient := &mockUpstreamResolverPerServer{
+        responses: map[string]mockUpstreamResponse{
+            ok.String():  {msg: buildMockResponse(dns.RcodeSuccess, "192.0.2.100")},
+            bad.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
+        },
+        rtt: time.Millisecond,
+    }
+
+    ctx, cancel := context.WithCancel(context.Background())
+    defer cancel()
+
+    resolver := &upstreamResolverBase{
+        ctx:             ctx,
+        upstreamClient:  mockClient,
+        upstreamTimeout: UpstreamTimeout,
+    }
+    resolver.addRace([]netip.AddrPort{ok, bad})
+
+    responseWriter := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}
+    resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
+
+    health := resolver.UpstreamHealth()
+    require.Contains(t, health, ok)
+    assert.False(t, health[ok].LastOk.IsZero(), "ok upstream should have LastOk set")
+    assert.Empty(t, health[ok].LastErr)
+
+    // bad upstream was never tried because ok answered first; its health
+    // should remain unset.
+    assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers")
+}
+
 func TestFormatFailures(t *testing.T) {
     testCases := []struct {
         name string
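The three tests above exercise the merged race across nameserver groups: the first good answer wins, SERVFAIL surfaces only when every group fails, and per-upstream health is recorded on the query path. As a rough, self-contained illustration of the "first successful answer wins" idea — not the resolver's actual implementation — racing two exchange functions could look like this (all names below are invented for the sketch):

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

type exchangeFn func(ctx context.Context) (string, error)

// raceFirstSuccess runs every exchange concurrently and returns the first
// non-error answer; it fails only once all exchanges have failed.
func raceFirstSuccess(ctx context.Context, fns []exchangeFn) (string, error) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel() // unblocks the losers once a winner is found

	results := make(chan string, len(fns))
	errs := make(chan error, len(fns))
	for _, fn := range fns {
		go func(fn exchangeFn) {
			answer, err := fn(ctx)
			if err != nil {
				errs <- err
				return
			}
			results <- answer
		}(fn)
	}

	var failures int
	for {
		select {
		case answer := <-results:
			return answer, nil
		case <-errs:
			failures++
			if failures == len(fns) {
				return "", errors.New("all upstreams failed")
			}
		case <-ctx.Done():
			return "", ctx.Err()
		}
	}
}

func main() {
	slow := func(ctx context.Context) (string, error) {
		select {
		case <-time.After(500 * time.Millisecond):
			return "", errors.New("i/o timeout")
		case <-ctx.Done():
			return "", ctx.Err()
		}
	}
	fast := func(ctx context.Context) (string, error) { return "192.0.2.100", nil }

	answer, err := raceFirstSuccess(context.Background(), []exchangeFn{slow, fast})
	fmt.Println(answer, err) // 192.0.2.100 <nil>, without waiting for the slow upstream
}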
@@ -665,10 +742,10 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
     // Verify that a client EDNS0 larger than our MTU-derived limit gets
     // capped in the outgoing request so the upstream doesn't send a
     // response larger than our read buffer.
-    var receivedUDPSize uint16
+    var receivedUDPSize atomic.Uint32
     udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
         if opt := r.IsEdns0(); opt != nil {
-            receivedUDPSize = opt.UDPSize()
+            receivedUDPSize.Store(uint32(opt.UDPSize()))
         }
         m := new(dns.Msg)
         m.SetReply(r)
@@ -699,7 +776,7 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
     require.NotNil(t, rm)

     expectedMax := uint16(currentMTU - ipUDPHeaderSize)
-    assert.Equal(t, expectedMax, receivedUDPSize,
+    assert.Equal(t, expectedMax, uint16(receivedUDPSize.Load()),
         "upstream should see capped EDNS0, not the client's 4096")
 }

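For context on the capping this test asserts: the advertised EDNS0 buffer is limited to the interface MTU minus the IP+UDP header overhead, so the upstream never sends a reply larger than the read buffer. A worked sketch, assuming a 1500-byte MTU and a 28-byte IPv4+UDP header (the real values come from currentMTU and ipUDPHeaderSize in the resolver code; both constants below are assumptions for illustration):

package main

import (
	"fmt"

	"github.com/miekg/dns"
)

func main() {
	const (
		assumedMTU      = 1500 // assumption: the resolver uses the current interface MTU
		ipUDPHeaderSize = 28   // assumption: 20-byte IPv4 header + 8-byte UDP header
	)

	clientRequested := uint16(4096)
	capped := uint16(assumedMTU - ipUDPHeaderSize) // 1500 - 28 = 1472

	if clientRequested > capped {
		clientRequested = capped
	}

	m := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
	m.SetEdns0(clientRequested, false)  // the upstream now sees at most 1472 bytes
	fmt.Println(m.IsEdns0().UDPSize())  // 1472
}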
@@ -504,16 +504,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)

     e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)

-    e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool {
-        for _, routes := range e.routeManager.GetSelectedClientRoutes() {
-            for _, r := range routes {
-                if r.Network.Contains(ip) {
-                    return true
-                }
-            }
-        }
-        return false
-    })
+    e.dnsServer.SetRouteSources(e.routeManager.GetSelectedClientRoutes, e.routeManager.GetActiveClientRoutes)

     if err = e.wgInterfaceCreate(); err != nil {
         log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
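SetRouteSources replaces the single SetRouteChecker callback: the DNS server now receives both route map providers (selected and active client routes) and runs the containment check itself. A simplified sketch of the check the removed closure performed, over a generic route map (route.HAMap's concrete key and value types are reduced to plain strings and prefixes here; the function name is illustrative):

package main

import (
	"fmt"
	"net/netip"
)

// containsIP reports whether any route in any HA group covers ip — the same
// walk the removed SetRouteChecker closure did over GetSelectedClientRoutes().
func containsIP(routes map[string][]netip.Prefix, ip netip.Addr) bool {
	for _, group := range routes {
		for _, network := range group {
			if network.Contains(ip) {
				return true
			}
		}
	}
	return false
}

func main() {
	routes := map[string][]netip.Prefix{
		"ha-group-1": {netip.MustParsePrefix("10.0.0.0/8")},
	}
	fmt.Println(containsIP(routes, netip.MustParseAddr("10.20.30.40"))) // true
	fmt.Println(containsIP(routes, netip.MustParseAddr("1.1.1.1")))     // false
}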
@@ -1336,9 +1327,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {

     e.networkSerial = serial

-    // Test received (upstream) servers for availability right away instead of upon usage.
-    // If no server of a server group responds this will disable the respective handler and retry later.
-    go e.dnsServer.ProbeAvailability()
     return nil
 }

@@ -1827,7 +1815,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
         return dnsServer, nil

     case "ios":
-        dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS)
+        dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
         return dnsServer, nil

     default:
@@ -53,6 +53,7 @@ type Manager interface {
     GetRouteSelector() *routeselector.RouteSelector
     GetClientRoutes() route.HAMap
     GetSelectedClientRoutes() route.HAMap
+    GetActiveClientRoutes() route.HAMap
     GetClientRoutesWithNetID() map[route.NetID][]*route.Route
     SetRouteChangeListener(listener listener.NetworkChangeListener)
     InitialRouteRange() []string
@@ -477,6 +478,39 @@ func (m *DefaultManager) GetSelectedClientRoutes() route.HAMap {
     return m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
 }

+// GetActiveClientRoutes returns the subset of selected client routes
+// that are currently reachable: the route's peer is Connected and is
+// the one actively carrying the route (not just an HA sibling).
+func (m *DefaultManager) GetActiveClientRoutes() route.HAMap {
+    m.mux.Lock()
+    selected := m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
+    recorder := m.statusRecorder
+    m.mux.Unlock()
+
+    if recorder == nil {
+        return selected
+    }
+
+    out := make(route.HAMap, len(selected))
+    for id, routes := range selected {
+        for _, r := range routes {
+            st, err := recorder.GetPeer(r.Peer)
+            if err != nil {
+                continue
+            }
+            if st.ConnStatus != peer.StatusConnected {
+                continue
+            }
+            if _, hasRoute := st.GetRoutes()[r.Network.String()]; !hasRoute {
+                continue
+            }
+            out[id] = routes
+            break
+        }
+    }
+    return out
+}
+
 // GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
 func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
     m.mux.Lock()
@@ -19,6 +19,7 @@ type MockManager struct {
     GetRouteSelectorFunc         func() *routeselector.RouteSelector
     GetClientRoutesFunc          func() route.HAMap
     GetSelectedClientRoutesFunc  func() route.HAMap
+    GetActiveClientRoutesFunc    func() route.HAMap
     GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
     StopFunc                     func(manager *statemanager.Manager)
 }

@@ -78,6 +79,14 @@ func (m *MockManager) GetSelectedClientRoutes() route.HAMap {
     return nil
 }

+// GetActiveClientRoutes mock implementation of GetActiveClientRoutes from the Manager interface
+func (m *MockManager) GetActiveClientRoutes() route.HAMap {
+    if m.GetActiveClientRoutesFunc != nil {
+        return m.GetActiveClientRoutesFunc()
+    }
+    return nil
+}
+
 // GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
 func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
     if m.GetClientRoutesWithNetIDFunc != nil {
@@ -161,11 +161,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
     cfg.WgIface = interfaceName

     c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
-    hostDNS := []netip.AddrPort{
-        netip.MustParseAddrPort("9.9.9.9:53"),
-        netip.MustParseAddrPort("149.112.112.112:53"),
-    }
-    return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, hostDNS, c.stateFile)
+    return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
 }

 // Stop the internal client and free the resources