diff --git a/client/internal/dns.go b/client/internal/dns.go index f5040ee49..a6604810f 100644 --- a/client/internal/dns.go +++ b/client/internal/dns.go @@ -12,52 +12,83 @@ import ( nbdns "github.com/netbirdio/netbird/dns" ) -func createPTRRecord(aRecord nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) { - ip, err := netip.ParseAddr(aRecord.RData) +func createPTRRecord(record nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) { + ip, err := netip.ParseAddr(record.RData) if err != nil { - log.Warnf("failed to parse IP address %s: %v", aRecord.RData, err) + log.Warnf("failed to parse IP address %s: %v", record.RData, err) return nbdns.SimpleRecord{}, false } + ip = ip.Unmap() if !prefix.Contains(ip) { return nbdns.SimpleRecord{}, false } - ipOctets := strings.Split(ip.String(), ".") - slices.Reverse(ipOctets) - rdnsName := dns.Fqdn(strings.Join(ipOctets, ".") + ".in-addr.arpa") + var rdnsName string + if ip.Is4() { + octets := strings.Split(ip.String(), ".") + slices.Reverse(octets) + rdnsName = dns.Fqdn(strings.Join(octets, ".") + ".in-addr.arpa") + } else { + // Expand to full 32 nibbles in reverse order (LSB first) per RFC 3596. + raw := ip.As16() + nibbles := make([]string, 32) + for i := 0; i < 16; i++ { + nibbles[31-i*2] = fmt.Sprintf("%x", raw[i]>>4) + nibbles[31-i*2-1] = fmt.Sprintf("%x", raw[i]&0x0f) + } + rdnsName = dns.Fqdn(strings.Join(nibbles, ".") + ".ip6.arpa") + } return nbdns.SimpleRecord{ Name: rdnsName, Type: int(dns.TypePTR), - Class: aRecord.Class, - TTL: aRecord.TTL, - RData: dns.Fqdn(aRecord.Name), + Class: record.Class, + TTL: record.TTL, + RData: dns.Fqdn(record.Name), }, true } -// generateReverseZoneName creates the reverse DNS zone name for a given network +// generateReverseZoneName creates the reverse DNS zone name for a given network. +// For IPv4 it produces an in-addr.arpa name, for IPv6 an ip6.arpa name. func generateReverseZoneName(network netip.Prefix) (string, error) { - networkIP := network.Masked().Addr() + networkIP := network.Masked().Addr().Unmap() + bits := network.Bits() - if !networkIP.Is4() { - return "", fmt.Errorf("reverse DNS is only supported for IPv4 networks, got: %s", networkIP) + if networkIP.Is4() { + // Round up to nearest byte. + octetsToUse := (bits + 7) / 8 + + octets := strings.Split(networkIP.String(), ".") + if octetsToUse > len(octets) { + return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", bits) + } + + reverseOctets := make([]string, octetsToUse) + for i := 0; i < octetsToUse; i++ { + reverseOctets[octetsToUse-1-i] = octets[i] + } + + return dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa"), nil } - // round up to nearest byte - octetsToUse := (network.Bits() + 7) / 8 + // IPv6: round up to nearest nibble (4-bit boundary). + nibblesToUse := (bits + 3) / 4 - octets := strings.Split(networkIP.String(), ".") - if octetsToUse > len(octets) { - return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", network.Bits()) + raw := networkIP.As16() + allNibbles := make([]string, 32) + for i := 0; i < 16; i++ { + allNibbles[i*2] = fmt.Sprintf("%x", raw[i]>>4) + allNibbles[i*2+1] = fmt.Sprintf("%x", raw[i]&0x0f) } - reverseOctets := make([]string, octetsToUse) - for i := 0; i < octetsToUse; i++ { - reverseOctets[octetsToUse-1-i] = octets[i] + // Take the first nibblesToUse nibbles (network portion), reverse them. + used := make([]string, nibblesToUse) + for i := 0; i < nibblesToUse; i++ { + used[nibblesToUse-1-i] = allNibbles[i] } - return dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa"), nil + return dns.Fqdn(strings.Join(used, ".") + ".ip6.arpa"), nil } // zoneExists checks if a zone with the given name already exists in the configuration @@ -71,7 +102,7 @@ func zoneExists(config *nbdns.Config, zoneName string) bool { return false } -// collectPTRRecords gathers all PTR records for the given network from A records +// collectPTRRecords gathers all PTR records for the given network from A and AAAA records. func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.SimpleRecord { var records []nbdns.SimpleRecord @@ -80,7 +111,7 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple continue } for _, record := range zone.Records { - if record.Type != int(dns.TypeA) { + if record.Type != int(dns.TypeA) && record.Type != int(dns.TypeAAAA) { continue } diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index b3908f163..0f4eb6bf8 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -298,6 +298,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() { ip = ip.Unmap() serverAddresses = append(serverAddresses, ip) + // Prefer the first IPv4 server as ServerIP since our DNS listener is IPv4. if !dnsSettings.ServerIP.IsValid() && ip.Is4() { dnsSettings.ServerIP = ip } diff --git a/client/internal/dns/network_manager_unix.go b/client/internal/dns/network_manager_unix.go index e4ccc8cbd..b5b21dc39 100644 --- a/client/internal/dns/network_manager_unix.go +++ b/client/internal/dns/network_manager_unix.go @@ -110,8 +110,15 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st connSettings.cleanDeprecatedSettings() - convDNSIP := binary.LittleEndian.Uint32(config.ServerIP.AsSlice()) - connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSKey] = dbus.MakeVariant([]uint32{convDNSIP}) + ipKey := networkManagerDbusIPv4Key + if config.ServerIP.Is6() { + ipKey = networkManagerDbusIPv6Key + raw := config.ServerIP.As16() + connSettings[ipKey][networkManagerDbusDNSKey] = dbus.MakeVariant([][]byte{raw[:]}) + } else { + convDNSIP := binary.LittleEndian.Uint32(config.ServerIP.AsSlice()) + connSettings[ipKey][networkManagerDbusDNSKey] = dbus.MakeVariant([]uint32{convDNSIP}) + } var ( searchDomains []string matchDomains []string @@ -146,8 +153,8 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st n.routingAll = false } - connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority) - connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList) + connSettings[ipKey][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority) + connSettings[ipKey][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList) state := &ShutdownState{ ManagerType: networkManager, diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index d9854c033..573dff540 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -90,8 +90,12 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool { } func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { + family := int32(unix.AF_INET) + if config.ServerIP.Is6() { + family = unix.AF_INET6 + } defaultLinkInput := systemdDbusDNSInput{ - Family: unix.AF_INET, + Family: family, Address: config.ServerIP.AsSlice(), } if err := s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}); err != nil { diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index 4d053a5a1..236c4d8e5 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -21,6 +21,8 @@ type upstreamResolverIOS struct { *upstreamResolverBase lIP netip.Addr lNet netip.Prefix + lIPv6 netip.Addr + lNetV6 netip.Prefix interfaceName string } @@ -37,6 +39,8 @@ func newUpstreamResolver( upstreamResolverBase: upstreamResolverBase, lIP: wgIface.Address().IP, lNet: wgIface.Address().Network, + lIPv6: wgIface.Address().IPv6, + lNetV6: wgIface.Address().IPv6Net, interfaceName: wgIface.Name(), } ios.upstreamClient = ios @@ -65,11 +69,27 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * } else { upstreamIP = upstreamIP.Unmap() } - if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() { - log.Debugf("using private client to query upstream: %s", upstream) - client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout) - if err != nil { - return nil, 0, fmt.Errorf("error while creating private client: %s", err) + // TODO: IsPrivate is a rough heuristic. It misses public IPs routed through + // the tunnel (e.g. 9.9.9.9 via network route) and incorrectly matches local + // LAN private IPs. Replace with a check against the active route table or + // the set of routed prefixes from the network map. + needsPrivate := u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() || + (u.lNetV6.IsValid() && u.lNetV6.Contains(upstreamIP)) + if needsPrivate { + var bindIP netip.Addr + switch { + case upstreamIP.Is6() && u.lIPv6.IsValid(): + bindIP = u.lIPv6 + case upstreamIP.Is4() && u.lIP.IsValid(): + bindIP = u.lIP + } + + if bindIP.IsValid() { + log.Debugf("using private client to query upstream: %s", upstream) + client, err = GetClientPrivate(bindIP, u.interfaceName, timeout) + if err != nil { + return nil, 0, fmt.Errorf("create private client: %s", err) + } } } @@ -86,16 +106,18 @@ func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Dura return nil, err } + proto, opt := unix.IPPROTO_IP, unix.IP_BOUND_IF + if ip.Is6() { + proto, opt = unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF + } + dialer := &net.Dialer{ - LocalAddr: &net.UDPAddr{ - IP: ip.AsSlice(), - Port: 0, // Let the OS pick a free port - }, + LocalAddr: net.UDPAddrFromAddrPort(netip.AddrPortFrom(ip, 0)), Timeout: dialTimeout, Control: func(network, address string, c syscall.RawConn) error { var operr error fn := func(s uintptr) { - operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, index) + operr = unix.SetsockoptInt(int(s), proto, opt, index) } if err := c.Control(fn); err != nil { diff --git a/client/internal/dns_test.go b/client/internal/dns_test.go new file mode 100644 index 000000000..e15cc8fb7 --- /dev/null +++ b/client/internal/dns_test.go @@ -0,0 +1,138 @@ +package internal + +import ( + "net/netip" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbdns "github.com/netbirdio/netbird/dns" +) + +func TestCreatePTRRecord_IPv4(t *testing.T) { + record := nbdns.SimpleRecord{ + Name: "peer1.netbird.cloud.", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "100.64.0.5", + } + prefix := netip.MustParsePrefix("100.64.0.0/16") + + ptr, ok := createPTRRecord(record, prefix) + require.True(t, ok) + assert.Equal(t, "5.0.64.100.in-addr.arpa.", ptr.Name) + assert.Equal(t, int(dns.TypePTR), ptr.Type) + assert.Equal(t, "peer1.netbird.cloud.", ptr.RData) +} + +func TestCreatePTRRecord_IPv6(t *testing.T) { + record := nbdns.SimpleRecord{ + Name: "peer1.netbird.cloud.", + Type: int(dns.TypeAAAA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "fd00:1234:5678::1", + } + prefix := netip.MustParsePrefix("fd00:1234:5678::/48") + + ptr, ok := createPTRRecord(record, prefix) + require.True(t, ok) + assert.Equal(t, "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.7.6.5.4.3.2.1.0.0.d.f.ip6.arpa.", ptr.Name) + assert.Equal(t, int(dns.TypePTR), ptr.Type) + assert.Equal(t, "peer1.netbird.cloud.", ptr.RData) +} + +func TestCreatePTRRecord_OutOfRange(t *testing.T) { + record := nbdns.SimpleRecord{ + Name: "peer1.netbird.cloud.", + Type: int(dns.TypeA), + RData: "10.0.0.1", + } + prefix := netip.MustParsePrefix("100.64.0.0/16") + + _, ok := createPTRRecord(record, prefix) + assert.False(t, ok) +} + +func TestGenerateReverseZoneName_IPv4(t *testing.T) { + tests := []struct { + prefix string + expected string + }{ + {"100.64.0.0/16", "64.100.in-addr.arpa."}, + {"10.0.0.0/8", "10.in-addr.arpa."}, + {"192.168.1.0/24", "1.168.192.in-addr.arpa."}, + } + + for _, tt := range tests { + t.Run(tt.prefix, func(t *testing.T) { + zone, err := generateReverseZoneName(netip.MustParsePrefix(tt.prefix)) + require.NoError(t, err) + assert.Equal(t, tt.expected, zone) + }) + } +} + +func TestGenerateReverseZoneName_IPv6(t *testing.T) { + tests := []struct { + prefix string + expected string + }{ + {"fd00:1234:5678::/48", "8.7.6.5.4.3.2.1.0.0.d.f.ip6.arpa."}, + {"fd00::/16", "0.0.d.f.ip6.arpa."}, + {"fd12:3456:789a:bcde::/64", "e.d.c.b.a.9.8.7.6.5.4.3.2.1.d.f.ip6.arpa."}, + } + + for _, tt := range tests { + t.Run(tt.prefix, func(t *testing.T) { + zone, err := generateReverseZoneName(netip.MustParsePrefix(tt.prefix)) + require.NoError(t, err) + assert.Equal(t, tt.expected, zone) + }) + } +} + +func TestCollectPTRRecords_BothFamilies(t *testing.T) { + config := &nbdns.Config{ + CustomZones: []nbdns.CustomZone{ + { + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud.", Type: int(dns.TypeA), RData: "100.64.0.1"}, + {Name: "peer1.netbird.cloud.", Type: int(dns.TypeAAAA), RData: "fd00::1"}, + {Name: "peer2.netbird.cloud.", Type: int(dns.TypeA), RData: "100.64.0.2"}, + }, + }, + }, + } + + v4Records := collectPTRRecords(config, netip.MustParsePrefix("100.64.0.0/16")) + assert.Len(t, v4Records, 2, "should collect 2 A record PTRs for the v4 prefix") + + v6Records := collectPTRRecords(config, netip.MustParsePrefix("fd00::/64")) + assert.Len(t, v6Records, 1, "should collect 1 AAAA record PTR for the v6 prefix") +} + +func TestAddReverseZone_IPv6(t *testing.T) { + config := &nbdns.Config{ + CustomZones: []nbdns.CustomZone{ + { + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud.", Type: int(dns.TypeAAAA), RData: "fd00:1234:5678::1"}, + }, + }, + }, + } + + addReverseZone(config, netip.MustParsePrefix("fd00:1234:5678::/48")) + + require.Len(t, config.CustomZones, 2) + reverseZone := config.CustomZones[1] + assert.Equal(t, "8.7.6.5.4.3.2.1.0.0.d.f.ip6.arpa.", reverseZone.Domain) + assert.Len(t, reverseZone.Records, 1) + assert.Equal(t, int(dns.TypePTR), reverseZone.Records[0].Type) +} diff --git a/client/internal/engine.go b/client/internal/engine.go index 2fc1617b4..e34bec00d 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -28,11 +28,10 @@ import ( "github.com/netbirdio/netbird/client/firewall" firewallManager "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/iface/wgaddr" - "github.com/netbirdio/netbird/shared/netiputil" "github.com/netbirdio/netbird/client/iface/device" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/udpmux" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/debug" "github.com/netbirdio/netbird/client/internal/dns" @@ -63,6 +62,7 @@ import ( mgm "github.com/netbirdio/netbird/shared/management/client" "github.com/netbirdio/netbird/shared/management/domain" mgmProto "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/netiputil" auth "github.com/netbirdio/netbird/shared/relay/auth/hmac" relayClient "github.com/netbirdio/netbird/shared/relay/client" signal "github.com/netbirdio/netbird/shared/signal/client" @@ -1252,7 +1252,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { protoDNSConfig = &mgmProto.DNSConfig{} } - dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network) + dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address()) if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil { log.Errorf("failed to update dns server, err: %v", err) @@ -1407,7 +1407,9 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE return entries } -func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config { +func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, addr wgaddr.Address) nbdns.Config { + network := addr.Network + networkV6 := addr.IPv6Net //nolint forwarderPort := uint16(protoDNSConfig.GetForwarderPort()) if forwarderPort == 0 { @@ -1464,6 +1466,9 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns if len(dnsUpdate.CustomZones) > 0 { addReverseZone(&dnsUpdate, network) + if networkV6.IsValid() { + addReverseZone(&dnsUpdate, networkV6) + } } return dnsUpdate @@ -1789,7 +1794,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err return nil, nil, false, err } routes := toRoutes(netMap.GetRoutes()) - dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address().Network) + dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address()) dnsFeatureFlag := toDNSFeatureFlag(netMap) return routes, &dnsCfg, dnsFeatureFlag, nil }