diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 267c1ed80..a4651ebb5 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -94,7 +94,7 @@ func NewDefaultServer( var dnsService service if wgInterface.IsUserspaceBind() { - dnsService = newServiceViaMemory(wgInterface) + dnsService = NewServiceViaMemory(wgInterface) } else { dnsService = newServiceViaListener(wgInterface, addrPort) } @@ -112,7 +112,7 @@ func NewDefaultServerPermanentUpstream( statusRecorder *peer.Status, ) *DefaultServer { log.Debugf("host dns address list is: %v", hostsDnsList) - ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder) + ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder) ds.hostsDNSHolder.set(hostsDnsList) ds.permanent = true ds.addHostRootZone() @@ -130,7 +130,7 @@ func NewDefaultServerIos( iosDnsManager IosDnsManager, statusRecorder *peer.Status, ) *DefaultServer { - ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder) + ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder) ds.iosDnsManager = iosDnsManager return ds } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 6cbd9ea15..b9552bc17 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -534,7 +534,7 @@ func TestDNSServerStartStop(t *testing.T) { func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { hostManager := &mockHostConfigurator{} server := DefaultServer{ - service: newServiceViaMemory(&mocWGIface{}), + service: NewServiceViaMemory(&mocWGIface{}), localResolver: &localResolver{ registeredMap: make(registrationMap), }, diff --git a/client/internal/dns/service_memory.go b/client/internal/dns/service_memory.go index 757cd962a..729b90cc0 100644 --- a/client/internal/dns/service_memory.go +++ b/client/internal/dns/service_memory.go @@ -12,7 +12,7 @@ import ( log "github.com/sirupsen/logrus" ) -type serviceViaMemory struct { +type ServiceViaMemory struct { wgInterface WGIface dnsMux *dns.ServeMux runtimeIP string @@ -22,8 +22,8 @@ type serviceViaMemory struct { listenerFlagLock sync.Mutex } -func newServiceViaMemory(wgIface WGIface) *serviceViaMemory { - s := &serviceViaMemory{ +func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory { + s := &ServiceViaMemory{ wgInterface: wgIface, dnsMux: dns.NewServeMux(), @@ -33,7 +33,7 @@ func newServiceViaMemory(wgIface WGIface) *serviceViaMemory { return s } -func (s *serviceViaMemory) Listen() error { +func (s *ServiceViaMemory) Listen() error { s.listenerFlagLock.Lock() defer s.listenerFlagLock.Unlock() @@ -52,7 +52,7 @@ func (s *serviceViaMemory) Listen() error { return nil } -func (s *serviceViaMemory) Stop() { +func (s *ServiceViaMemory) Stop() { s.listenerFlagLock.Lock() defer s.listenerFlagLock.Unlock() @@ -67,23 +67,23 @@ func (s *serviceViaMemory) Stop() { s.listenerIsRunning = false } -func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) { +func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) { s.dnsMux.Handle(pattern, handler) } -func (s *serviceViaMemory) DeregisterMux(pattern string) { +func (s *ServiceViaMemory) DeregisterMux(pattern string) { s.dnsMux.HandleRemove(pattern) } -func (s *serviceViaMemory) RuntimePort() int { +func (s *ServiceViaMemory) RuntimePort() int { return s.runtimePort } -func (s *serviceViaMemory) RuntimeIP() string { +func (s *ServiceViaMemory) RuntimeIP() string { return s.runtimeIP } -func (s *serviceViaMemory) filterDNSTraffic() (string, error) { +func (s *ServiceViaMemory) filterDNSTraffic() (string, error) { filter := s.wgInterface.GetFilter() if filter == nil { return "", fmt.Errorf("can't set DNS filter, filter not initialized") diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index 0c01a013e..60ed79d87 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -4,6 +4,7 @@ package dns import ( "context" + "fmt" "net" "syscall" "time" @@ -17,9 +18,9 @@ import ( type upstreamResolverIOS struct { *upstreamResolverBase - lIP net.IP - lNet *net.IPNet - iIndex int + lIP net.IP + lNet *net.IPNet + interfaceName string } func newUpstreamResolver( @@ -32,17 +33,11 @@ func newUpstreamResolver( ) (*upstreamResolverIOS, error) { upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder) - index, err := getInterfaceIndex(interfaceName) - if err != nil { - log.Debugf("unable to get interface index for %s: %s", interfaceName, err) - return nil, err - } - ios := &upstreamResolverIOS{ upstreamResolverBase: upstreamResolverBase, lIP: ip, lNet: net, - iIndex: index, + interfaceName: interfaceName, } ios.upstreamClient = ios @@ -53,7 +48,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * client := &dns.Client{} upstreamHost, _, err := net.SplitHostPort(upstream) if err != nil { - log.Errorf("error while parsing upstream host: %s", err) + return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err) } timeout := upstreamTimeout @@ -65,26 +60,35 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * upstreamIP := net.ParseIP(upstreamHost) if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) { log.Debugf("using private client to query upstream: %s", upstream) - client = u.getClientPrivate(timeout) + client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout) + if err != nil { + return nil, 0, fmt.Errorf("error while creating private client: %s", err) + } } // Cannot use client.ExchangeContext because it overwrites our Dialer return client.Exchange(r, upstream) } -// getClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface +// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface // This method is needed for iOS -func (u *upstreamResolverIOS) getClientPrivate(dialTimeout time.Duration) *dns.Client { +func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { + index, err := getInterfaceIndex(interfaceName) + if err != nil { + log.Debugf("unable to get interface index for %s: %s", interfaceName, err) + return nil, err + } + dialer := &net.Dialer{ LocalAddr: &net.UDPAddr{ - IP: u.lIP, + IP: ip, Port: 0, // Let the OS pick a free port }, Timeout: dialTimeout, Control: func(network, address string, c syscall.RawConn) error { var operr error fn := func(s uintptr) { - operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, u.iIndex) + operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, index) } if err := c.Control(fn); err != nil { @@ -101,7 +105,7 @@ func (u *upstreamResolverIOS) getClientPrivate(dialTimeout time.Duration) *dns.C client := &dns.Client{ Dialer: dialer, } - return client + return client, nil } func getInterfaceIndex(interfaceName string) (int, error) { diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 92c71b1e0..1566d10dd 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -10,6 +10,7 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" + nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" @@ -65,7 +66,7 @@ func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration routePeersNotifiers: make(map[string]chan struct{}), routeUpdate: make(chan routesUpdate), peerStateUpdate: make(chan struct{}), - handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder), + handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface), } return client } @@ -383,9 +384,10 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { } } -func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status) RouteHandler { +func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface *iface.WGIface) RouteHandler { if rt.IsDynamic() { - return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder) + dns := nbdns.NewServiceViaMemory(wgInterface) + return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort())) } return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter) } diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go index e95710798..3296f3ddf 100644 --- a/client/internal/routemanager/dynamic/route.go +++ b/client/internal/routemanager/dynamic/route.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/util" + "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/route" ) @@ -47,6 +48,8 @@ type Route struct { currentPeerKey string cancel context.CancelFunc statusRecorder *peer.Status + wgInterface *iface.WGIface + resolverAddr string } func NewRoute( @@ -55,6 +58,8 @@ func NewRoute( allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, interval time.Duration, statusRecorder *peer.Status, + wgInterface *iface.WGIface, + resolverAddr string, ) *Route { return &Route{ route: rt, @@ -63,6 +68,8 @@ func NewRoute( interval: interval, dynamicDomains: domainMap{}, statusRecorder: statusRecorder, + wgInterface: wgInterface, + resolverAddr: resolverAddr, } } @@ -228,11 +235,17 @@ func (r *Route) resolve(results chan resolveResult) { wg.Add(1) go func(domain domain.Domain) { defer wg.Done() - ips, err := net.LookupIP(string(domain)) + + ips, err := r.getIPsFromResolver(domain) if err != nil { - results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)} - return + log.Tracef("Failed to resolve domain %s with private resolver: %v", domain.SafeString(), err) + ips, err = net.LookupIP(string(domain)) + if err != nil { + results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)} + return + } } + for _, ip := range ips { prefix, err := util.GetPrefixFromIP(ip) if err != nil { diff --git a/client/internal/routemanager/dynamic/route_generic.go b/client/internal/routemanager/dynamic/route_generic.go new file mode 100644 index 000000000..cf3d913a4 --- /dev/null +++ b/client/internal/routemanager/dynamic/route_generic.go @@ -0,0 +1,13 @@ +//go:build !ios + +package dynamic + +import ( + "net" + + "github.com/netbirdio/netbird/management/domain" +) + +func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) { + return net.LookupIP(string(domain)) +} diff --git a/client/internal/routemanager/dynamic/route_ios.go b/client/internal/routemanager/dynamic/route_ios.go new file mode 100644 index 000000000..67138222f --- /dev/null +++ b/client/internal/routemanager/dynamic/route_ios.go @@ -0,0 +1,55 @@ +//go:build ios + +package dynamic + +import ( + "fmt" + "net" + "time" + + "github.com/miekg/dns" + + nbdns "github.com/netbirdio/netbird/client/internal/dns" + + "github.com/netbirdio/netbird/management/domain" +) + +const dialTimeout = 10 * time.Second + +func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) { + privateClient, err := nbdns.GetClientPrivate(r.wgInterface.Address().IP, r.wgInterface.Name(), dialTimeout) + if err != nil { + return nil, fmt.Errorf("error while creating private client: %s", err) + } + + msg := new(dns.Msg) + msg.SetQuestion(dns.Fqdn(string(domain)), dns.TypeA) + + startTime := time.Now() + + response, _, err := privateClient.Exchange(msg, r.resolverAddr) + if err != nil { + return nil, fmt.Errorf("DNS query for %s failed after %s: %s ", domain.SafeString(), time.Since(startTime), err) + } + + if response.Rcode != dns.RcodeSuccess { + return nil, fmt.Errorf("dns response code: %s", dns.RcodeToString[response.Rcode]) + } + + ips := make([]net.IP, 0) + + for _, answ := range response.Answer { + if aRecord, ok := answ.(*dns.A); ok { + ips = append(ips, aRecord.A) + } + if aaaaRecord, ok := answ.(*dns.AAAA); ok { + ips = append(ips, aaaaRecord.AAAA) + } + } + + if len(ips) == 0 { + return nil, fmt.Errorf("no A or AAAA records found for %s", domain.SafeString()) + } + + return ips, nil +}