Compare commits

...

9 Commits

Author SHA1 Message Date
Viktor Liu
ff5eddf70b Merge branch 'main' into add-ns-punnycode-support 2025-06-08 13:14:52 +02:00
Viktor Liu
273160c682 [client] Use punycode domains internally consequently (#3867) 2025-05-24 18:25:15 +02:00
bcmmbaga
1d6c360aec fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-23 13:07:26 +03:00
bcmmbaga
f04e7c3f06 Merge branch 'main' into add-ns-punnycode-support
# Conflicts:
#	management/server/nameserver.go
#	management/server/nameserver_test.go
2025-05-23 13:00:19 +03:00
bcmmbaga
3d89cd43c2 fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-22 22:44:30 +03:00
bcmmbaga
0eeda712d0 add support for punycode domain
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-22 22:44:12 +03:00
bcmmbaga
3e3268db5f Remove support for wildcard ns match domain
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-22 19:01:53 +03:00
bcmmbaga
31f0879e71 remove the leading dot and root dot support ns regex
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-22 18:51:05 +03:00
bcmmbaga
f25b5bb987 Enhance match domain validation logic and add test cases
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-22 16:35:45 +03:00
38 changed files with 287 additions and 259 deletions

View File

@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/management/domain"
) )
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform // WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
@@ -43,11 +44,11 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind
} }
} }
func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) { func (t *WGTunDevice) Create(routes []string, dns string, searchDomains domain.List) (WGConfigurer, error) {
log.Info("create tun interface") log.Info("create tun interface")
routesString := routesToString(routes) routesString := routesToString(routes)
searchDomainsToString := searchDomainsToString(searchDomains) searchDomainsToString := searchDomainsToString(searchDomains.ToPunycodeList())
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString) fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
if err != nil { if err != nil {

View File

@@ -8,10 +8,11 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/management/domain"
) )
type WGTunDevice interface { type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error) Create(routes []string, dns string, searchDomains domain.List) (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address WgAddress() wgaddr.Address

View File

@@ -2,7 +2,11 @@
package iface package iface
import "fmt" import (
"fmt"
"github.com/netbirdio/netbird/management/domain"
)
// Create creates a new Wireguard interface, sets a given IP and brings it up. // Create creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one. // Will reuse an existing one.
@@ -21,6 +25,6 @@ func (w *WGIface) Create() error {
} }
// CreateOnAndroid this function make sense on mobile only // CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error { func (w *WGIface) CreateOnAndroid([]string, string, domain.List) error {
return fmt.Errorf("this function has not implemented on non mobile") return fmt.Errorf("this function has not implemented on non mobile")
} }

View File

@@ -2,11 +2,13 @@ package iface
import ( import (
"fmt" "fmt"
"github.com/netbirdio/netbird/management/domain"
) )
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up. // CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one. // Will reuse an existing one.
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error { func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains domain.List) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()

View File

@@ -7,6 +7,8 @@ import (
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
"github.com/netbirdio/netbird/management/domain"
) )
// Create creates a new Wireguard interface, sets a given IP and brings it up. // Create creates a new Wireguard interface, sets a given IP and brings it up.
@@ -36,6 +38,6 @@ func (w *WGIface) Create() error {
} }
// CreateOnAndroid this function make sense on mobile only // CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error { func (w *WGIface) CreateOnAndroid([]string, string, domain.List) error {
return fmt.Errorf("this function has not implemented on this platform") return fmt.Errorf("this function has not implemented on this platform")
} }

View File

@@ -239,7 +239,7 @@ func searchDomains(config HostDNSConfig) []string {
continue continue
} }
listOfDomains = append(listOfDomains, strings.TrimSuffix(dConf.Domain, ".")) listOfDomains = append(listOfDomains, strings.TrimSuffix(dConf.Domain.PunycodeString(), "."))
} }
return listOfDomains return listOfDomains
} }

View File

@@ -8,6 +8,8 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/domain"
) )
const ( const (
@@ -24,8 +26,8 @@ type SubdomainMatcher interface {
type HandlerEntry struct { type HandlerEntry struct {
Handler dns.Handler Handler dns.Handler
Priority int Priority int
Pattern string Pattern domain.Domain
OrigPattern string OrigPattern domain.Domain
IsWildcard bool IsWildcard bool
MatchSubdomains bool MatchSubdomains bool
} }
@@ -39,7 +41,7 @@ type HandlerChain struct {
// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain // ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain
type ResponseWriterChain struct { type ResponseWriterChain struct {
dns.ResponseWriter dns.ResponseWriter
origPattern string origPattern domain.Domain
shouldContinue bool shouldContinue bool
} }
@@ -59,18 +61,18 @@ func NewHandlerChain() *HandlerChain {
} }
// GetOrigPattern returns the original pattern of the handler that wrote the response // GetOrigPattern returns the original pattern of the handler that wrote the response
func (w *ResponseWriterChain) GetOrigPattern() string { func (w *ResponseWriterChain) GetOrigPattern() domain.Domain {
return w.origPattern return w.origPattern
} }
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority // AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int) { func (c *HandlerChain) AddHandler(pattern domain.Domain, handler dns.Handler, priority int) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
pattern = strings.ToLower(dns.Fqdn(pattern)) pattern = domain.Domain(strings.ToLower(dns.Fqdn(pattern.PunycodeString())))
origPattern := pattern origPattern := pattern
isWildcard := strings.HasPrefix(pattern, "*.") isWildcard := strings.HasPrefix(pattern.PunycodeString(), "*.")
if isWildcard { if isWildcard {
pattern = pattern[2:] pattern = pattern[2:]
} }
@@ -110,8 +112,8 @@ func (c *HandlerChain) findHandlerPosition(newEntry HandlerEntry) int {
// domain specificity next // domain specificity next
if h.Priority == newEntry.Priority { if h.Priority == newEntry.Priority {
newDots := strings.Count(newEntry.Pattern, ".") newDots := strings.Count(newEntry.Pattern.PunycodeString(), ".")
existingDots := strings.Count(h.Pattern, ".") existingDots := strings.Count(h.Pattern.PunycodeString(), ".")
if newDots > existingDots { if newDots > existingDots {
return i return i
} }
@@ -123,20 +125,20 @@ func (c *HandlerChain) findHandlerPosition(newEntry HandlerEntry) int {
} }
// RemoveHandler removes a handler for the given pattern and priority // RemoveHandler removes a handler for the given pattern and priority
func (c *HandlerChain) RemoveHandler(pattern string, priority int) { func (c *HandlerChain) RemoveHandler(pattern domain.Domain, priority int) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
pattern = dns.Fqdn(pattern) pattern = domain.Domain(dns.Fqdn(pattern.PunycodeString()))
c.removeEntry(pattern, priority) c.removeEntry(pattern, priority)
} }
func (c *HandlerChain) removeEntry(pattern string, priority int) { func (c *HandlerChain) removeEntry(pattern domain.Domain, priority int) {
// Find and remove handlers matching both original pattern (case-insensitive) and priority // Find and remove handlers matching both original pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- { for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i] entry := c.handlers[i]
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority { if strings.EqualFold(entry.OrigPattern.PunycodeString(), pattern.PunycodeString()) && entry.Priority == priority {
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
break break
} }
@@ -201,16 +203,16 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
case entry.Pattern == ".": case entry.Pattern == ".":
return true return true
case entry.IsWildcard: case entry.IsWildcard:
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".") parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern.PunycodeString()), ".")
return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern) return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern.PunycodeString())
default: default:
// For non-wildcard patterns: // For non-wildcard patterns:
// If handler wants subdomain matching, allow suffix match // If handler wants subdomain matching, allow suffix match
// Otherwise require exact match // Otherwise require exact match
if entry.MatchSubdomains { if entry.MatchSubdomains {
return strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern) return strings.EqualFold(qname, entry.Pattern.PunycodeString()) || strings.HasSuffix(qname, "."+entry.Pattern.PunycodeString())
} else { } else {
return strings.EqualFold(qname, entry.Pattern) return strings.EqualFold(qname, entry.Pattern.PunycodeString())
} }
} }
} }

View File

@@ -9,6 +9,7 @@ import (
nbdns "github.com/netbirdio/netbird/client/internal/dns" nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/test" "github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/management/domain"
) )
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order // TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
@@ -50,8 +51,8 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) { func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
handlerDomain string handlerDomain domain.Domain
queryDomain string queryDomain domain.Domain
isWildcard bool isWildcard bool
matchSubdomains bool matchSubdomains bool
shouldMatch bool shouldMatch bool
@@ -141,7 +142,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
chain.AddHandler(pattern, handler, nbdns.PriorityDefault) chain.AddHandler(pattern, handler, nbdns.PriorityDefault)
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA) r.SetQuestion(tt.queryDomain.PunycodeString(), dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
chain.ServeDNS(w, r) chain.ServeDNS(w, r)
@@ -160,17 +161,17 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
handlers []struct { handlers []struct {
pattern string pattern domain.Domain
priority int priority int
} }
queryDomain string queryDomain domain.Domain
expectedCalls int expectedCalls int
expectedHandler int // index of the handler that should be called expectedHandler int // index of the handler that should be called
}{ }{
{ {
name: "wildcard and exact same priority - exact should win", name: "wildcard and exact same priority - exact should win",
handlers: []struct { handlers: []struct {
pattern string pattern domain.Domain
priority int priority int
}{ }{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault}, {pattern: "*.example.com.", priority: nbdns.PriorityDefault},
@@ -183,7 +184,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
{ {
name: "higher priority wildcard over lower priority exact", name: "higher priority wildcard over lower priority exact",
handlers: []struct { handlers: []struct {
pattern string pattern domain.Domain
priority int priority int
}{ }{
{pattern: "example.com.", priority: nbdns.PriorityDefault}, {pattern: "example.com.", priority: nbdns.PriorityDefault},
@@ -196,7 +197,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
{ {
name: "multiple wildcards different priorities", name: "multiple wildcards different priorities",
handlers: []struct { handlers: []struct {
pattern string pattern domain.Domain
priority int priority int
}{ }{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault}, {pattern: "*.example.com.", priority: nbdns.PriorityDefault},
@@ -210,7 +211,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
{ {
name: "subdomain with mix of patterns", name: "subdomain with mix of patterns",
handlers: []struct { handlers: []struct {
pattern string pattern domain.Domain
priority int priority int
}{ }{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault}, {pattern: "*.example.com.", priority: nbdns.PriorityDefault},
@@ -224,7 +225,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
{ {
name: "root zone with specific domain", name: "root zone with specific domain",
handlers: []struct { handlers: []struct {
pattern string pattern domain.Domain
priority int priority int
}{ }{
{pattern: ".", priority: nbdns.PriorityDefault}, {pattern: ".", priority: nbdns.PriorityDefault},
@@ -258,7 +259,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
// Create and execute request // Create and execute request
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA) r.SetQuestion(tt.queryDomain.PunycodeString(), dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
chain.ServeDNS(w, r) chain.ServeDNS(w, r)
@@ -330,7 +331,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
name string name string
ops []struct { ops []struct {
action string // "add" or "remove" action string // "add" or "remove"
pattern string pattern domain.Domain
priority int priority int
} }
query string query string
@@ -340,7 +341,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
name: "remove high priority keeps lower priority handler", name: "remove high priority keeps lower priority handler",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
}{ }{
{"add", "example.com.", nbdns.PriorityDNSRoute}, {"add", "example.com.", nbdns.PriorityDNSRoute},
@@ -357,7 +358,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
name: "remove lower priority keeps high priority handler", name: "remove lower priority keeps high priority handler",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
}{ }{
{"add", "example.com.", nbdns.PriorityDNSRoute}, {"add", "example.com.", nbdns.PriorityDNSRoute},
@@ -374,7 +375,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
name: "remove all handlers in order", name: "remove all handlers in order",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
}{ }{
{"add", "example.com.", nbdns.PriorityDNSRoute}, {"add", "example.com.", nbdns.PriorityDNSRoute},
@@ -436,7 +437,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
func TestHandlerChain_MultiPriorityHandling(t *testing.T) { func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
chain := nbdns.NewHandlerChain() chain := nbdns.NewHandlerChain()
testDomain := "example.com." testDomain := domain.Domain("example.com.")
testQuery := "test.example.com." testQuery := "test.example.com."
// Create handlers with MatchSubdomains enabled // Create handlers with MatchSubdomains enabled
@@ -518,7 +519,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name string name string
scenario string scenario string
addHandlers []struct { addHandlers []struct {
pattern string pattern domain.Domain
priority int priority int
subdomains bool subdomains bool
shouldMatch bool shouldMatch bool
@@ -530,7 +531,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name: "case insensitive exact match", name: "case insensitive exact match",
scenario: "handler registered lowercase, query uppercase", scenario: "handler registered lowercase, query uppercase",
addHandlers: []struct { addHandlers: []struct {
pattern string pattern domain.Domain
priority int priority int
subdomains bool subdomains bool
shouldMatch bool shouldMatch bool
@@ -544,7 +545,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name: "case insensitive wildcard match", name: "case insensitive wildcard match",
scenario: "handler registered mixed case wildcard, query different case", scenario: "handler registered mixed case wildcard, query different case",
addHandlers: []struct { addHandlers: []struct {
pattern string pattern domain.Domain
priority int priority int
subdomains bool subdomains bool
shouldMatch bool shouldMatch bool
@@ -558,7 +559,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name: "multiple handlers different case same domain", name: "multiple handlers different case same domain",
scenario: "second handler should replace first despite case difference", scenario: "second handler should replace first despite case difference",
addHandlers: []struct { addHandlers: []struct {
pattern string pattern domain.Domain
priority int priority int
subdomains bool subdomains bool
shouldMatch bool shouldMatch bool
@@ -573,7 +574,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name: "subdomain matching case insensitive", name: "subdomain matching case insensitive",
scenario: "handler with MatchSubdomains true should match regardless of case", scenario: "handler with MatchSubdomains true should match regardless of case",
addHandlers: []struct { addHandlers: []struct {
pattern string pattern domain.Domain
priority int priority int
subdomains bool subdomains bool
shouldMatch bool shouldMatch bool
@@ -587,7 +588,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name: "root zone case insensitive", name: "root zone case insensitive",
scenario: "root zone handler should match regardless of case", scenario: "root zone handler should match regardless of case",
addHandlers: []struct { addHandlers: []struct {
pattern string pattern domain.Domain
priority int priority int
subdomains bool subdomains bool
shouldMatch bool shouldMatch bool
@@ -601,7 +602,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name: "multiple handlers different priority", name: "multiple handlers different priority",
scenario: "should call higher priority handler despite case differences", scenario: "should call higher priority handler despite case differences",
addHandlers: []struct { addHandlers: []struct {
pattern string pattern domain.Domain
priority int priority int
subdomains bool subdomains bool
shouldMatch bool shouldMatch bool
@@ -618,7 +619,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain() chain := nbdns.NewHandlerChain()
handlerCalls := make(map[string]bool) // track which patterns were called handlerCalls := make(map[domain.Domain]bool) // track which patterns were called
// Add handlers according to test case // Add handlers according to test case
for _, h := range tt.addHandlers { for _, h := range tt.addHandlers {
@@ -686,19 +687,19 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
scenario string scenario string
ops []struct { ops []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
subdomain bool subdomain bool
} }
query string query domain.Domain
expectedMatch string expectedMatch domain.Domain
}{ }{
{ {
name: "more specific domain matches first", name: "more specific domain matches first",
scenario: "sub.example.com should match before example.com", scenario: "sub.example.com should match before example.com",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
subdomain bool subdomain bool
}{ }{
@@ -713,7 +714,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
scenario: "sub.example.com should match before example.com", scenario: "sub.example.com should match before example.com",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
subdomain bool subdomain bool
}{ }{
@@ -728,7 +729,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
scenario: "after removing most specific, should fall back to less specific", scenario: "after removing most specific, should fall back to less specific",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
subdomain bool subdomain bool
}{ }{
@@ -745,7 +746,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
scenario: "less specific domain with higher priority should match first", scenario: "less specific domain with higher priority should match first",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
subdomain bool subdomain bool
}{ }{
@@ -760,7 +761,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
scenario: "with equal priority, more specific domain should match", scenario: "with equal priority, more specific domain should match",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
subdomain bool subdomain bool
}{ }{
@@ -776,7 +777,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
scenario: "specific domain should match before wildcard at same priority", scenario: "specific domain should match before wildcard at same priority",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
subdomain bool subdomain bool
}{ }{
@@ -791,7 +792,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain() chain := nbdns.NewHandlerChain()
handlers := make(map[string]*nbdns.MockSubdomainHandler) handlers := make(map[domain.Domain]*nbdns.MockSubdomainHandler)
for _, op := range tt.ops { for _, op := range tt.ops {
if op.action == "add" { if op.action == "add" {
@@ -804,7 +805,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
} }
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA) r.SetQuestion(tt.query.PunycodeString(), dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Setup handler expectations // Setup handler expectations
@@ -836,9 +837,9 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) { func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
addPattern string addPattern domain.Domain
removePattern string removePattern domain.Domain
queryPattern string queryPattern domain.Domain
shouldBeRemoved bool shouldBeRemoved bool
description string description string
}{ }{
@@ -954,7 +955,7 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
handler := &nbdns.MockHandler{} handler := &nbdns.MockHandler{}
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.queryPattern, dns.TypeA) r.SetQuestion(tt.queryPattern.PunycodeString(), dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// First verify no handler is called before adding any // First verify no handler is called before adding any

View File

@@ -9,6 +9,7 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
) )
var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured") var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
@@ -40,7 +41,7 @@ type HostDNSConfig struct {
type DomainConfig struct { type DomainConfig struct {
Disabled bool `json:"disabled"` Disabled bool `json:"disabled"`
Domain string `json:"domain"` Domain domain.Domain `json:"domain"`
MatchOnly bool `json:"matchOnly"` MatchOnly bool `json:"matchOnly"`
} }
@@ -103,18 +104,20 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
config.RouteAll = true config.RouteAll = true
} }
for _, domain := range nsConfig.Domains { for _, d := range nsConfig.Domains {
d := strings.ToLower(dns.Fqdn(d.PunycodeString()))
config.Domains = append(config.Domains, DomainConfig{ config.Domains = append(config.Domains, DomainConfig{
Domain: strings.ToLower(dns.Fqdn(domain)), Domain: domain.Domain(d),
MatchOnly: !nsConfig.SearchDomainsEnabled, MatchOnly: !nsConfig.SearchDomainsEnabled,
}) })
} }
} }
for _, customZone := range dnsConfig.CustomZones { for _, customZone := range dnsConfig.CustomZones {
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone) d := strings.ToLower(dns.Fqdn(customZone.Domain))
matchOnly := strings.HasSuffix(d, ipv4ReverseZone) || strings.HasSuffix(d, ipv6ReverseZone)
config.Domains = append(config.Domains, DomainConfig{ config.Domains = append(config.Domains, DomainConfig{
Domain: strings.ToLower(dns.Fqdn(customZone.Domain)), Domain: domain.Domain(d),
MatchOnly: matchOnly, MatchOnly: matchOnly,
}) })
} }

View File

@@ -79,10 +79,10 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
continue continue
} }
if dConf.MatchOnly { if dConf.MatchOnly {
matchDomains = append(matchDomains, strings.TrimSuffix(dConf.Domain, ".")) matchDomains = append(matchDomains, strings.TrimSuffix(dConf.Domain.PunycodeString(), "."))
continue continue
} }
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, ".")) searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain.PunycodeString(), "."))
} }
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)

View File

@@ -186,9 +186,9 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
continue continue
} }
if !dConf.MatchOnly { if !dConf.MatchOnly {
searchDomains = append(searchDomains, strings.TrimSuffix(dConf.Domain, ".")) searchDomains = append(searchDomains, strings.TrimSuffix(dConf.Domain.PunycodeString(), "."))
} }
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, ".")) matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain.PunycodeString(), "."))
} }
if len(matchDomains) != 0 { if len(matchDomains) != 0 {

View File

@@ -62,8 +62,8 @@ func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
return fmt.Errorf("method UpdateDNSServer is not implemented") return fmt.Errorf("method UpdateDNSServer is not implemented")
} }
func (m *MockServer) SearchDomains() []string { func (m *MockServer) SearchDomains() domain.List {
return make([]string, 0) return make(domain.List, 0)
} }
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface // ProbeAvailability mocks implementation of ProbeAvailability from the Server interface

View File

@@ -125,10 +125,10 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st
continue continue
} }
if dConf.MatchOnly { if dConf.MatchOnly {
matchDomains = append(matchDomains, "~."+dConf.Domain) matchDomains = append(matchDomains, "~."+dConf.Domain.PunycodeString())
continue continue
} }
searchDomains = append(searchDomains, dConf.Domain) searchDomains = append(searchDomains, dConf.Domain.PunycodeString())
} }
newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic

View File

@@ -1,21 +1,19 @@
package dns package dns
import ( import (
"reflect"
"sort"
"sync" "sync"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/management/domain"
) )
type notifier struct { type notifier struct {
listener listener.NetworkChangeListener listener listener.NetworkChangeListener
listenerMux sync.Mutex listenerMux sync.Mutex
searchDomains []string searchDomains domain.List
} }
func newNotifier(initialSearchDomains []string) *notifier { func newNotifier(initialSearchDomains domain.List) *notifier {
sort.Strings(initialSearchDomains)
return &notifier{ return &notifier{
searchDomains: initialSearchDomains, searchDomains: initialSearchDomains,
} }
@@ -27,16 +25,8 @@ func (n *notifier) setListener(listener listener.NetworkChangeListener) {
n.listener = listener n.listener = listener
} }
func (n *notifier) onNewSearchDomains(searchDomains []string) { func (n *notifier) onNewSearchDomains(searchDomains domain.List) {
sort.Strings(searchDomains) if searchDomains.Equal(n.searchDomains) {
if len(n.searchDomains) != len(searchDomains) {
n.searchDomains = searchDomains
n.notify()
return
}
if reflect.DeepEqual(n.searchDomains, searchDomains) {
return return
} }

View File

@@ -44,12 +44,12 @@ type Server interface {
DnsIP() string DnsIP() string
UpdateDNSServer(serial uint64, update nbdns.Config) error UpdateDNSServer(serial uint64, update nbdns.Config) error
OnUpdatedHostDNSServer(strings []string) OnUpdatedHostDNSServer(strings []string)
SearchDomains() []string SearchDomains() domain.List
ProbeAvailability() ProbeAvailability()
} }
type nsGroupsByDomain struct { type nsGroupsByDomain struct {
domain string domain domain.Domain
groups []*nbdns.NameServerGroup groups []*nbdns.NameServerGroup
} }
@@ -90,7 +90,7 @@ type handlerWithStop interface {
} }
type handlerWrapper struct { type handlerWrapper struct {
domain string domain domain.Domain
handler handlerWithStop handler handlerWithStop
priority int priority int
} }
@@ -197,7 +197,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
s.registerHandler(domains.ToPunycodeList(), handler, priority) s.registerHandler(domains, handler, priority)
// TODO: This will take over zones for non-wildcard domains, for which we might not have a handler in the chain // TODO: This will take over zones for non-wildcard domains, for which we might not have a handler in the chain
for _, domain := range domains { for _, domain := range domains {
@@ -207,7 +207,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
s.applyHostConfig() s.applyHostConfig()
} }
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) { func (s *DefaultServer) registerHandler(domains domain.List, handler dns.Handler, priority int) {
log.Debugf("registering handler %s with priority %d", handler, priority) log.Debugf("registering handler %s with priority %d", handler, priority)
for _, domain := range domains { for _, domain := range domains {
@@ -224,7 +224,7 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
s.deregisterHandler(domains.ToPunycodeList(), priority) s.deregisterHandler(domains, priority)
for _, domain := range domains { for _, domain := range domains {
zone := toZone(domain) zone := toZone(domain)
s.extraDomains[zone]-- s.extraDomains[zone]--
@@ -235,7 +235,7 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
s.applyHostConfig() s.applyHostConfig()
} }
func (s *DefaultServer) deregisterHandler(domains []string, priority int) { func (s *DefaultServer) deregisterHandler(domains domain.List, priority int) {
log.Debugf("deregistering handler %v with priority %d", domains, priority) log.Debugf("deregistering handler %v with priority %d", domains, priority)
for _, domain := range domains { for _, domain := range domains {
@@ -378,8 +378,8 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
return nil return nil
} }
func (s *DefaultServer) SearchDomains() []string { func (s *DefaultServer) SearchDomains() domain.List {
var searchDomains []string var searchDomains domain.List
for _, dConf := range s.currentConfig.Domains { for _, dConf := range s.currentConfig.Domains {
if dConf.Disabled { if dConf.Disabled {
@@ -472,18 +472,16 @@ func (s *DefaultServer) applyHostConfig() {
config := s.currentConfig config := s.currentConfig
existingDomains := make(map[string]struct{}) existingDomains := make(map[domain.Domain]struct{})
for _, d := range config.Domains { for _, d := range config.Domains {
existingDomains[d.Domain] = struct{}{} existingDomains[d.Domain] = struct{}{}
} }
// add extra domains only if they're not already in the config // add extra domains only if they're not already in the config
for domain := range s.extraDomains { for d := range s.extraDomains {
domainStr := domain.PunycodeString() if _, exists := existingDomains[d]; !exists {
if _, exists := existingDomains[domainStr]; !exists {
config.Domains = append(config.Domains, DomainConfig{ config.Domains = append(config.Domains, DomainConfig{
Domain: domainStr, Domain: d,
MatchOnly: true, MatchOnly: true,
}) })
} }
@@ -525,7 +523,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
} }
muxUpdates = append(muxUpdates, handlerWrapper{ muxUpdates = append(muxUpdates, handlerWrapper{
domain: customZone.Domain, domain: domain.Domain(customZone.Domain),
handler: s.localResolver, handler: s.localResolver,
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
}) })
@@ -647,7 +645,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
// this will introduce a short period of time when the server is not able to handle DNS requests // this will introduce a short period of time when the server is not able to handle DNS requests
for _, existing := range s.dnsMuxMap { for _, existing := range s.dnsMuxMap {
s.deregisterHandler([]string{existing.domain}, existing.priority) s.deregisterHandler(domain.List{existing.domain}, existing.priority)
existing.handler.Stop() existing.handler.Stop()
} }
@@ -658,7 +656,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
if update.domain == nbdns.RootZone { if update.domain == nbdns.RootZone {
containsRootUpdate = true containsRootUpdate = true
} }
s.registerHandler([]string{update.domain}, update.handler, update.priority) s.registerHandler(domain.List{update.domain}, update.handler, update.priority)
muxUpdateMap[update.handler.ID()] = update muxUpdateMap[update.handler.ID()] = update
} }
@@ -687,7 +685,7 @@ func (s *DefaultServer) upstreamCallbacks(
handler dns.Handler, handler dns.Handler,
priority int, priority int,
) (deactivate func(error), reactivate func()) { ) (deactivate func(error), reactivate func()) {
var removeIndex map[string]int var removeIndex map[domain.Domain]int
deactivate = func(err error) { deactivate = func(err error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@@ -695,20 +693,20 @@ func (s *DefaultServer) upstreamCallbacks(
l := log.WithField("nameservers", nsGroup.NameServers) l := log.WithField("nameservers", nsGroup.NameServers)
l.Info("Temporarily deactivating nameservers group due to timeout") l.Info("Temporarily deactivating nameservers group due to timeout")
removeIndex = make(map[string]int) removeIndex = make(map[domain.Domain]int)
for _, domain := range nsGroup.Domains { for _, domain := range nsGroup.Domains {
removeIndex[domain] = -1 removeIndex[domain] = -1
} }
if nsGroup.Primary { if nsGroup.Primary {
removeIndex[nbdns.RootZone] = -1 removeIndex[nbdns.RootZone] = -1
s.currentConfig.RouteAll = false s.currentConfig.RouteAll = false
s.deregisterHandler([]string{nbdns.RootZone}, priority) s.deregisterHandler(domain.List{nbdns.RootZone}, priority)
} }
for i, item := range s.currentConfig.Domains { for i, item := range s.currentConfig.Domains {
if _, found := removeIndex[item.Domain]; found { if _, found := removeIndex[item.Domain]; found {
s.currentConfig.Domains[i].Disabled = true s.currentConfig.Domains[i].Disabled = true
s.deregisterHandler([]string{item.Domain}, priority) s.deregisterHandler(domain.List{item.Domain}, priority)
removeIndex[item.Domain] = i removeIndex[item.Domain] = i
} }
} }
@@ -732,12 +730,12 @@ func (s *DefaultServer) upstreamCallbacks(
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
for domain, i := range removeIndex { for d, i := range removeIndex {
if i == -1 || i >= len(s.currentConfig.Domains) || s.currentConfig.Domains[i].Domain != domain { if i == -1 || i >= len(s.currentConfig.Domains) || s.currentConfig.Domains[i].Domain != d{
continue continue
} }
s.currentConfig.Domains[i].Disabled = false s.currentConfig.Domains[i].Disabled = false
s.registerHandler([]string{domain}, handler, priority) s.registerHandler(domain.List{d}, handler, priority)
} }
l := log.WithField("nameservers", nsGroup.NameServers) l := log.WithField("nameservers", nsGroup.NameServers)
@@ -745,7 +743,7 @@ func (s *DefaultServer) upstreamCallbacks(
if nsGroup.Primary { if nsGroup.Primary {
s.currentConfig.RouteAll = true s.currentConfig.RouteAll = true
s.registerHandler([]string{nbdns.RootZone}, handler, priority) s.registerHandler(domain.List{nbdns.RootZone}, handler, priority)
} }
s.applyHostConfig() s.applyHostConfig()
@@ -777,7 +775,7 @@ func (s *DefaultServer) addHostRootZone() {
handler.deactivate = func(error) {} handler.deactivate = func(error) {}
handler.reactivate = func() {} handler.reactivate = func() {}
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault) s.registerHandler(domain.List{nbdns.RootZone}, handler, PriorityDefault)
} }
func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) { func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
@@ -792,7 +790,7 @@ func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
state := peer.NSGroupState{ state := peer.NSGroupState{
ID: generateGroupKey(group), ID: generateGroupKey(group),
Servers: servers, Servers: servers,
Domains: group.Domains, Domains: group.Domains.ToPunycodeList(),
// The probe will determine the state, default enabled // The probe will determine the state, default enabled
Enabled: true, Enabled: true,
Error: nil, Error: nil,
@@ -825,7 +823,7 @@ func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
// groupNSGroupsByDomain groups nameserver groups by their match domains // groupNSGroupsByDomain groups nameserver groups by their match domains
func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain { func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain {
domainMap := make(map[string][]*nbdns.NameServerGroup) domainMap := make(map[domain.Domain][]*nbdns.NameServerGroup)
for _, group := range nsGroups { for _, group := range nsGroups {
if group.Primary { if group.Primary {

View File

@@ -6,7 +6,6 @@ import (
"net" "net"
"net/netip" "net/netip"
"os" "os"
"strings"
"testing" "testing"
"time" "time"
@@ -96,7 +95,7 @@ func init() {
formatter.SetTextFormatter(log.StandardLogger()) formatter.SetTextFormatter(log.StandardLogger())
} }
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase { func generateDummyHandler(domain domain.Domain, servers []nbdns.NameServer) *upstreamResolverBase {
var srvs []string var srvs []string
for _, srv := range servers { for _, srv := range servers {
srvs = append(srvs, getNSHostPort(srv)) srvs = append(srvs, getNSHostPort(srv))
@@ -151,7 +150,7 @@ func TestUpdateDNSServer(t *testing.T) {
}, },
NameServerGroups: []*nbdns.NameServerGroup{ NameServerGroups: []*nbdns.NameServerGroup{
{ {
Domains: []string{"netbird.io"}, Domains: domain.List{"netbird.io"},
NameServers: nameServers, NameServers: nameServers,
}, },
{ {
@@ -183,7 +182,7 @@ func TestUpdateDNSServer(t *testing.T) {
name: "New Config Should Succeed", name: "New Config Should Succeed",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}, initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initUpstreamMap: registeredHandlerMap{ initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ generateDummyHandler(domain.Domain(zoneRecords[0].Name), nameServers).ID(): handlerWrapper{
domain: "netbird.cloud", domain: "netbird.cloud",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
@@ -201,7 +200,7 @@ func TestUpdateDNSServer(t *testing.T) {
}, },
NameServerGroups: []*nbdns.NameServerGroup{ NameServerGroups: []*nbdns.NameServerGroup{
{ {
Domains: []string{"netbird.io"}, Domains: domain.List{"netbird.io"},
NameServers: nameServers, NameServers: nameServers,
}, },
}, },
@@ -302,8 +301,8 @@ func TestUpdateDNSServer(t *testing.T) {
name: "Empty Config Should Succeed and Clean Maps", name: "Empty Config Should Succeed and Clean Maps",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}, initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initUpstreamMap: registeredHandlerMap{ initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ generateDummyHandler(domain.Domain(zoneRecords[0].Name), nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name, domain: domain.Domain(zoneRecords[0].Name),
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
}, },
@@ -318,8 +317,8 @@ func TestUpdateDNSServer(t *testing.T) {
name: "Disabled Service Should clean map", name: "Disabled Service Should clean map",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}, initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initUpstreamMap: registeredHandlerMap{ initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ generateDummyHandler(domain.Domain(zoneRecords[0].Name), nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name, domain: domain.Domain(zoneRecords[0].Name),
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
}, },
@@ -493,7 +492,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
dnsServer.dnsMuxMap = registeredHandlerMap{ dnsServer.dnsMuxMap = registeredHandlerMap{
"id1": handlerWrapper{ "id1": handlerWrapper{
domain: zoneRecords[0].Name, domain: domain.Domain(zoneRecords[0].Name),
handler: &local.Resolver{}, handler: &local.Resolver{},
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
}, },
@@ -525,7 +524,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
}, },
NameServerGroups: []*nbdns.NameServerGroup{ NameServerGroups: []*nbdns.NameServerGroup{
{ {
Domains: []string{"netbird.io"}, Domains: domain.List{"netbird.io"},
NameServers: nameServers, NameServers: nameServers,
}, },
{ {
@@ -591,7 +590,7 @@ func TestDNSServerStartStop(t *testing.T) {
t.Error(err) t.Error(err)
} }
dnsServer.registerHandler([]string{"netbird.cloud"}, dnsServer.localResolver, 1) dnsServer.registerHandler(domain.List{"netbird.cloud"}, dnsServer.localResolver, 1)
resolver := &net.Resolver{ resolver := &net.Resolver{
PreferGo: true, PreferGo: true,
@@ -651,48 +650,48 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
var domainsUpdate string var domainsUpdate string
hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error { hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error {
domains := []string{} domains := domain.List{}
for _, item := range config.Domains { for _, item := range config.Domains {
if item.Disabled { if item.Disabled {
continue continue
} }
domains = append(domains, item.Domain) domains = append(domains, item.Domain)
} }
domainsUpdate = strings.Join(domains, ",") domainsUpdate = domains.PunycodeString()
return nil return nil
} }
deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{ deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{
Domains: []string{"domain1"}, Domains: domain.List{"domain1"},
NameServers: []nbdns.NameServer{ NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53}, {IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
}, },
}, nil, 0) }, nil, 0)
deactivate(nil) deactivate(nil)
expected := "domain0,domain2" expected := "domain0, domain2"
domains := []string{} domains := domain.List{}
for _, item := range server.currentConfig.Domains { for _, item := range server.currentConfig.Domains {
if item.Disabled { if item.Disabled {
continue continue
} }
domains = append(domains, item.Domain) domains = append(domains, item.Domain)
} }
got := strings.Join(domains, ",") got := domains.PunycodeString()
if expected != got { if expected != got {
t.Errorf("expected domains list: %q, got %q", expected, got) t.Errorf("expected domains list: %q, got %q", expected, got)
} }
reactivate() reactivate()
expected = "domain0,domain1,domain2" expected = "domain0, domain1, domain2"
domains = []string{} domains = domain.List{}
for _, item := range server.currentConfig.Domains { for _, item := range server.currentConfig.Domains {
if item.Disabled { if item.Disabled {
continue continue
} }
domains = append(domains, item.Domain) domains = append(domains, item.Domain)
} }
got = strings.Join(domains, ",") got = domains.PunycodeString()
if expected != got { if expected != got {
t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate) t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate)
} }
@@ -860,7 +859,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
Port: 53, Port: 53,
}, },
}, },
Domains: []string{"google.com"}, Domains: domain.List{"google.com"},
Primary: false, Primary: false,
}, },
}, },
@@ -1115,7 +1114,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
name string name string
initialHandlers registeredHandlerMap initialHandlers registeredHandlerMap
updates []handlerWrapper updates []handlerWrapper
expectedHandlers map[string]string // map[HandlerID]domain expectedHandlers map[string]domain.Domain // map[HandlerID]domain
description string description string
}{ }{
{ {
@@ -1131,7 +1130,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityMatchDomain - 1, priority: PriorityMatchDomain - 1,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]domain.Domain{
"upstream-group2": "example.com", "upstream-group2": "example.com",
}, },
description: "When group1 is not included in the update, it should be removed while group2 remains", description: "When group1 is not included in the update, it should be removed while group2 remains",
@@ -1149,7 +1148,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]domain.Domain{
"upstream-group1": "example.com", "upstream-group1": "example.com",
}, },
description: "When group2 is not included in the update, it should be removed while group1 remains", description: "When group2 is not included in the update, it should be removed while group1 remains",
@@ -1182,7 +1181,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityMatchDomain - 1, priority: PriorityMatchDomain - 1,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]domain.Domain{
"upstream-group1": "example.com", "upstream-group1": "example.com",
"upstream-group2": "example.com", "upstream-group2": "example.com",
"upstream-group3": "example.com", "upstream-group3": "example.com",
@@ -1217,7 +1216,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityMatchDomain - 2, priority: PriorityMatchDomain - 2,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]domain.Domain{
"upstream-group1": "example.com", "upstream-group1": "example.com",
"upstream-group2": "example.com", "upstream-group2": "example.com",
"upstream-group3": "example.com", "upstream-group3": "example.com",
@@ -1237,7 +1236,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityDefault - 1, priority: PriorityDefault - 1,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]domain.Domain{
"upstream-root2": ".", "upstream-root2": ".",
}, },
description: "When root1 is not included in the update, it should be removed while root2 remains", description: "When root1 is not included in the update, it should be removed while root2 remains",
@@ -1254,7 +1253,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityDefault, priority: PriorityDefault,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]domain.Domain{
"upstream-root1": ".", "upstream-root1": ".",
}, },
description: "When root2 is not included in the update, it should be removed while root1 remains", description: "When root2 is not included in the update, it should be removed while root1 remains",
@@ -1285,7 +1284,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityDefault - 1, priority: PriorityDefault - 1,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]domain.Domain{
"upstream-root1": ".", "upstream-root1": ".",
"upstream-root2": ".", "upstream-root2": ".",
"upstream-root3": ".", "upstream-root3": ".",
@@ -1318,7 +1317,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityDefault - 2, priority: PriorityDefault - 2,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]domain.Domain{
"upstream-root1": ".", "upstream-root1": ".",
"upstream-root2": ".", "upstream-root2": ".",
"upstream-root3": ".", "upstream-root3": ".",
@@ -1345,7 +1344,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]domain.Domain{
"upstream-group1": "example.com", "upstream-group1": "example.com",
"upstream-other": "other.com", "upstream-other": "other.com",
}, },
@@ -1384,7 +1383,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]domain.Domain{
"upstream-group1": "example.com", "upstream-group1": "example.com",
"upstream-group2": "example.com", "upstream-group2": "example.com",
"upstream-other": "other.com", "upstream-other": "other.com",
@@ -1440,7 +1439,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
for _, muxEntry := range server.dnsMuxMap { for _, muxEntry := range server.dnsMuxMap {
if chainEntry.Handler == muxEntry.handler && if chainEntry.Handler == muxEntry.handler &&
chainEntry.Priority == muxEntry.priority && chainEntry.Priority == muxEntry.priority &&
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) { chainEntry.Pattern.PunycodeString() == dns.Fqdn(muxEntry.domain.PunycodeString()) {
foundInMux = true foundInMux = true
break break
} }
@@ -1459,8 +1458,8 @@ func TestExtraDomains(t *testing.T) {
registerDomains []domain.List registerDomains []domain.List
deregisterDomains []domain.List deregisterDomains []domain.List
finalConfig nbdns.Config finalConfig nbdns.Config
expectedDomains []string expectedDomains domain.List
expectedMatchOnly []string expectedMatchOnly domain.List
applyHostConfigCall int applyHostConfigCall int
}{ }{
{ {
@@ -1474,12 +1473,12 @@ func TestExtraDomains(t *testing.T) {
{Domain: "config.example.com"}, {Domain: "config.example.com"},
}, },
}, },
expectedDomains: []string{ expectedDomains: domain.List{
"config.example.com.", "config.example.com.",
"extra1.example.com.", "extra1.example.com.",
"extra2.example.com.", "extra2.example.com.",
}, },
expectedMatchOnly: []string{ expectedMatchOnly: domain.List{
"extra1.example.com.", "extra1.example.com.",
"extra2.example.com.", "extra2.example.com.",
}, },
@@ -1496,12 +1495,12 @@ func TestExtraDomains(t *testing.T) {
registerDomains: []domain.List{ registerDomains: []domain.List{
{"extra1.example.com", "extra2.example.com"}, {"extra1.example.com", "extra2.example.com"},
}, },
expectedDomains: []string{ expectedDomains: domain.List{
"config.example.com.", "config.example.com.",
"extra1.example.com.", "extra1.example.com.",
"extra2.example.com.", "extra2.example.com.",
}, },
expectedMatchOnly: []string{ expectedMatchOnly: domain.List{
"extra1.example.com.", "extra1.example.com.",
"extra2.example.com.", "extra2.example.com.",
}, },
@@ -1519,12 +1518,12 @@ func TestExtraDomains(t *testing.T) {
registerDomains: []domain.List{ registerDomains: []domain.List{
{"extra.example.com", "overlap.example.com"}, {"extra.example.com", "overlap.example.com"},
}, },
expectedDomains: []string{ expectedDomains: domain.List{
"config.example.com.", "config.example.com.",
"overlap.example.com.", "overlap.example.com.",
"extra.example.com.", "extra.example.com.",
}, },
expectedMatchOnly: []string{ expectedMatchOnly: domain.List{
"extra.example.com.", "extra.example.com.",
}, },
applyHostConfigCall: 2, applyHostConfigCall: 2,
@@ -1544,12 +1543,12 @@ func TestExtraDomains(t *testing.T) {
deregisterDomains: []domain.List{ deregisterDomains: []domain.List{
{"extra1.example.com", "extra3.example.com"}, {"extra1.example.com", "extra3.example.com"},
}, },
expectedDomains: []string{ expectedDomains: domain.List{
"config.example.com.", "config.example.com.",
"extra2.example.com.", "extra2.example.com.",
"extra4.example.com.", "extra4.example.com.",
}, },
expectedMatchOnly: []string{ expectedMatchOnly: domain.List{
"extra2.example.com.", "extra2.example.com.",
"extra4.example.com.", "extra4.example.com.",
}, },
@@ -1570,13 +1569,13 @@ func TestExtraDomains(t *testing.T) {
deregisterDomains: []domain.List{ deregisterDomains: []domain.List{
{"duplicate.example.com"}, {"duplicate.example.com"},
}, },
expectedDomains: []string{ expectedDomains: domain.List{
"config.example.com.", "config.example.com.",
"extra.example.com.", "extra.example.com.",
"other.example.com.", "other.example.com.",
"duplicate.example.com.", "duplicate.example.com.",
}, },
expectedMatchOnly: []string{ expectedMatchOnly: domain.List{
"extra.example.com.", "extra.example.com.",
"other.example.com.", "other.example.com.",
"duplicate.example.com.", "duplicate.example.com.",
@@ -1601,13 +1600,13 @@ func TestExtraDomains(t *testing.T) {
{Domain: "newconfig.example.com"}, {Domain: "newconfig.example.com"},
}, },
}, },
expectedDomains: []string{ expectedDomains: domain.List{
"config.example.com.", "config.example.com.",
"newconfig.example.com.", "newconfig.example.com.",
"extra.example.com.", "extra.example.com.",
"duplicate.example.com.", "duplicate.example.com.",
}, },
expectedMatchOnly: []string{ expectedMatchOnly: domain.List{
"extra.example.com.", "extra.example.com.",
"duplicate.example.com.", "duplicate.example.com.",
}, },
@@ -1628,12 +1627,12 @@ func TestExtraDomains(t *testing.T) {
deregisterDomains: []domain.List{ deregisterDomains: []domain.List{
{"protected.example.com"}, {"protected.example.com"},
}, },
expectedDomains: []string{ expectedDomains: domain.List{
"extra.example.com.", "extra.example.com.",
"config.example.com.", "config.example.com.",
"protected.example.com.", "protected.example.com.",
}, },
expectedMatchOnly: []string{ expectedMatchOnly: domain.List{
"extra.example.com.", "extra.example.com.",
}, },
applyHostConfigCall: 3, applyHostConfigCall: 3,
@@ -1644,7 +1643,7 @@ func TestExtraDomains(t *testing.T) {
ServiceEnable: true, ServiceEnable: true,
NameServerGroups: []*nbdns.NameServerGroup{ NameServerGroups: []*nbdns.NameServerGroup{
{ {
Domains: []string{"ns.example.com", "overlap.ns.example.com"}, Domains: domain.List{"ns.example.com", "overlap.ns.example.com"},
NameServers: []nbdns.NameServer{ NameServers: []nbdns.NameServer{
{ {
IP: netip.MustParseAddr("8.8.8.8"), IP: netip.MustParseAddr("8.8.8.8"),
@@ -1658,12 +1657,12 @@ func TestExtraDomains(t *testing.T) {
registerDomains: []domain.List{ registerDomains: []domain.List{
{"extra.example.com", "overlap.ns.example.com"}, {"extra.example.com", "overlap.ns.example.com"},
}, },
expectedDomains: []string{ expectedDomains: domain.List{
"ns.example.com.", "ns.example.com.",
"overlap.ns.example.com.", "overlap.ns.example.com.",
"extra.example.com.", "extra.example.com.",
}, },
expectedMatchOnly: []string{ expectedMatchOnly: domain.List{
"ns.example.com.", "ns.example.com.",
"overlap.ns.example.com.", "overlap.ns.example.com.",
"extra.example.com.", "extra.example.com.",
@@ -1734,8 +1733,8 @@ func TestExtraDomains(t *testing.T) {
lastConfig := capturedConfigs[len(capturedConfigs)-1] lastConfig := capturedConfigs[len(capturedConfigs)-1]
// Check all expected domains are present // Check all expected domains are present
domainMap := make(map[string]bool) domainMap := make(map[domain.Domain]bool)
matchOnlyMap := make(map[string]bool) matchOnlyMap := make(map[domain.Domain]bool)
for _, d := range lastConfig.Domains { for _, d := range lastConfig.Domains {
domainMap[d.Domain] = true domainMap[d.Domain] = true
@@ -1852,12 +1851,12 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
err := server.applyConfiguration(initialConfig) err := server.applyConfiguration(initialConfig)
assert.NoError(t, err) assert.NoError(t, err)
var domains []string var domains domain.List
for _, d := range capturedConfig.Domains { for _, d := range capturedConfig.Domains {
domains = append(domains, d.Domain) domains = append(domains, d.Domain)
} }
assert.Contains(t, domains, "config.example.com.") assert.Contains(t, domains, domain.Domain("config.example.com."))
assert.Contains(t, domains, "extra.example.com.") assert.Contains(t, domains, domain.Domain("extra.example.com."))
// Now apply a new configuration with overlapping domain // Now apply a new configuration with overlapping domain
updatedConfig := nbdns.Config{ updatedConfig := nbdns.Config{
@@ -1871,7 +1870,7 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
// Verify both domains are in config, but no duplicates // Verify both domains are in config, but no duplicates
domains = []string{} domains = domain.List{}
matchOnlyCount := 0 matchOnlyCount := 0
for _, d := range capturedConfig.Domains { for _, d := range capturedConfig.Domains {
domains = append(domains, d.Domain) domains = append(domains, d.Domain)
@@ -1880,12 +1879,12 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
} }
} }
assert.Contains(t, domains, "config.example.com.") assert.Contains(t, domains, domain.Domain("config.example.com."))
assert.Contains(t, domains, "extra.example.com.") assert.Contains(t, domains, domain.Domain("extra.example.com."))
assert.Equal(t, 2, len(domains), "Should have exactly 2 domains with no duplicates") assert.Equal(t, 2, len(domains), "Should have exactly 2 domains with no duplicates")
// Extra domain should no longer be marked as match-only when in config // Extra domain should no longer be marked as match-only when in config
matchOnlyDomain := "" var matchOnlyDomain domain.Domain
for _, d := range capturedConfig.Domains { for _, d := range capturedConfig.Domains {
if d.Domain == "extra.example.com." && d.MatchOnly { if d.Domain == "extra.example.com." && d.MatchOnly {
matchOnlyDomain = d.Domain matchOnlyDomain = d.Domain
@@ -1938,10 +1937,10 @@ func TestDomainCaseHandling(t *testing.T) {
err := server.applyConfiguration(config) err := server.applyConfiguration(config)
assert.NoError(t, err) assert.NoError(t, err)
var domains []string var domains domain.List
for _, d := range capturedConfig.Domains { for _, d := range capturedConfig.Domains {
domains = append(domains, d.Domain) domains = append(domains, d.Domain)
} }
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent") assert.Contains(t, domains, domain.Domain("config.example.com."), "Mixed case domain should be normalized and pre.sent")
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present") assert.Contains(t, domains, domain.Domain("mixed.example.com."), "Mixed case domain should be normalized and present")
} }

View File

@@ -117,15 +117,15 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
continue continue
} }
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{ domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
Domain: dConf.Domain, Domain: dConf.Domain.PunycodeString(),
MatchOnly: dConf.MatchOnly, MatchOnly: dConf.MatchOnly,
}) })
if dConf.MatchOnly { if dConf.MatchOnly {
matchDomains = append(matchDomains, dConf.Domain) matchDomains = append(matchDomains, dConf.Domain.PunycodeString())
continue continue
} }
searchDomains = append(searchDomains, dConf.Domain) searchDomains = append(searchDomains, dConf.Domain.PunycodeString())
} }
if config.RouteAll { if config.RouteAll {

View File

@@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/domain"
) )
const ( const (
@@ -48,7 +49,7 @@ type upstreamResolverBase struct {
cancel context.CancelFunc cancel context.CancelFunc
upstreamClient upstreamClient upstreamClient upstreamClient
upstreamServers []string upstreamServers []string
domain string domain domain.Domain
disabled bool disabled bool
failsCount atomic.Int32 failsCount atomic.Int32
successCount atomic.Int32 successCount atomic.Int32
@@ -62,7 +63,7 @@ type upstreamResolverBase struct {
statusRecorder *peer.Status statusRecorder *peer.Status
} }
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase { func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain domain.Domain) *upstreamResolverBase {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
return &upstreamResolverBase{ return &upstreamResolverBase{

View File

@@ -10,6 +10,7 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/management/domain"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
@@ -28,7 +29,7 @@ func newUpstreamResolver(
_ netip.Prefix, _ netip.Prefix,
statusRecorder *peer.Status, statusRecorder *peer.Status,
hostsDNSHolder *hostsDNSHolder, hostsDNSHolder *hostsDNSHolder,
domain string, domain domain.Domain,
) (*upstreamResolver, error) { ) (*upstreamResolver, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
c := &upstreamResolver{ c := &upstreamResolver{

View File

@@ -10,6 +10,7 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/management/domain"
) )
type upstreamResolver struct { type upstreamResolver struct {
@@ -23,7 +24,7 @@ func newUpstreamResolver(
_ netip.Prefix, _ netip.Prefix,
statusRecorder *peer.Status, statusRecorder *peer.Status,
_ *hostsDNSHolder, _ *hostsDNSHolder,
domain string, domain domain.Domain,
) (*upstreamResolver, error) { ) (*upstreamResolver, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
nonIOS := &upstreamResolver{ nonIOS := &upstreamResolver{

View File

@@ -15,6 +15,7 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/management/domain"
) )
type upstreamResolverIOS struct { type upstreamResolverIOS struct {
@@ -31,7 +32,7 @@ func newUpstreamResolver(
net netip.Prefix, net netip.Prefix,
statusRecorder *peer.Status, statusRecorder *peer.Status,
_ *hostsDNSHolder, _ *hostsDNSHolder,
domain string, domain domain.Domain,
) (*upstreamResolverIOS, error) { ) (*upstreamResolverIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)

View File

@@ -1165,7 +1165,7 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns
for _, nsGroup := range protoDNSConfig.GetNameServerGroups() { for _, nsGroup := range protoDNSConfig.GetNameServerGroups() {
dnsNSGroup := &nbdns.NameServerGroup{ dnsNSGroup := &nbdns.NameServerGroup{
Primary: nsGroup.GetPrimary(), Primary: nsGroup.GetPrimary(),
Domains: nsGroup.GetDomains(), Domains: domain.FromPunycodeList(nsGroup.GetDomains()),
SearchDomainsEnabled: nsGroup.GetSearchDomainsEnabled(), SearchDomainsEnabled: nsGroup.GetSearchDomainsEnabled(),
} }
for _, ns := range nsGroup.GetNameServers() { for _, ns := range nsGroup.GetNameServers() {

View File

@@ -44,6 +44,7 @@ import (
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
mgmt "github.com/netbirdio/netbird/management/client" mgmt "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/management/domain"
mgmtProto "github.com/netbirdio/netbird/management/proto" mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
@@ -77,7 +78,7 @@ var (
type MockWGIface struct { type MockWGIface struct {
CreateFunc func() error CreateFunc func() error
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error CreateOnAndroidFunc func(routeRange []string, ip string, domains domain.List) error
IsUserspaceBindFunc func() bool IsUserspaceBindFunc func() bool
NameFunc func() string NameFunc func() string
AddressFunc func() wgaddr.Address AddressFunc func() wgaddr.Address
@@ -111,7 +112,7 @@ func (m *MockWGIface) Create() error {
return m.CreateFunc() return m.CreateFunc()
} }
func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error { func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains domain.List) error {
return m.CreateOnAndroidFunc(routeRange, ip, domains) return m.CreateOnAndroidFunc(routeRange, ip, domains)
} }

View File

@@ -14,11 +14,12 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/management/domain"
) )
type wgIfaceBase interface { type wgIfaceBase interface {
Create() error Create() error
CreateOnAndroid(routeRange []string, ip string, domains []string) error CreateOnAndroid(routeRange []string, ip string, domains domain.List) error
IsUserspaceBind() bool IsUserspaceBind() bool
Name() string Name() string
Address() wgaddr.Address Address() wgaddr.Address

View File

@@ -229,15 +229,14 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
} }
if len(r.Answer) > 0 && len(r.Question) > 0 { if len(r.Answer) > 0 && len(r.Question) > 0 {
origPattern := "" var origPattern domain.Domain
if writer, ok := w.(*nbdns.ResponseWriterChain); ok { if writer, ok := w.(*nbdns.ResponseWriterChain); ok {
origPattern = writer.GetOrigPattern() origPattern = writer.GetOrigPattern()
} }
resolvedDomain := domain.Domain(strings.ToLower(r.Question[0].Name)) resolvedDomain := domain.Domain(strings.ToLower(r.Question[0].Name))
// already punycode via RegisterHandler() originalDomain := origPattern
originalDomain := domain.Domain(origPattern)
if originalDomain == "" { if originalDomain == "" {
originalDomain = resolvedDomain originalDomain = resolvedDomain
} }

View File

@@ -6,6 +6,8 @@ import (
"net/url" "net/url"
"strconv" "strconv"
"strings" "strings"
"github.com/netbirdio/netbird/management/domain"
) )
const ( const (
@@ -64,7 +66,7 @@ type NameServerGroup struct {
// Primary indicates that the nameserver group is the primary resolver for any dns query // Primary indicates that the nameserver group is the primary resolver for any dns query
Primary bool Primary bool
// Domains indicate the dns query domains to use with this nameserver group // Domains indicate the dns query domains to use with this nameserver group
Domains []string `gorm:"serializer:json"` Domains domain.List `gorm:"serializer:json"`
// Enabled group status // Enabled group status
Enabled bool Enabled bool
// SearchDomainsEnabled indicates whether to add match domains to search domains list or not // SearchDomainsEnabled indicates whether to add match domains to search domains list or not
@@ -142,7 +144,7 @@ func (g *NameServerGroup) Copy() *NameServerGroup {
Groups: make([]string, len(g.Groups)), Groups: make([]string, len(g.Groups)),
Enabled: g.Enabled, Enabled: g.Enabled,
Primary: g.Primary, Primary: g.Primary,
Domains: make([]string, len(g.Domains)), Domains: make(domain.List, len(g.Domains)),
SearchDomainsEnabled: g.SearchDomainsEnabled, SearchDomainsEnabled: g.SearchDomainsEnabled,
} }
@@ -188,7 +190,7 @@ func containsNameServer(element NameServer, list []NameServer) bool {
return false return false
} }
func compareGroupsList(list, other []string) bool { func compareGroupsList[T comparable](list, other []T) bool {
if len(list) != len(other) { if len(list) != len(other) {
return false return false
} }

View File

@@ -30,7 +30,7 @@ func (d Domain) SafeString() string {
} }
// PunycodeString returns the punycode representation of the Domain. // PunycodeString returns the punycode representation of the Domain.
// This should only be used if a punycode domain is expected but only a string is supported. // This should only be used if a punycode domain is expected but only a string is supported (e.g. an external library).
func (d Domain) PunycodeString() string { func (d Domain) PunycodeString() string {
return string(d) return string(d)
} }

View File

@@ -1,7 +1,7 @@
package domain package domain
import ( import (
"sort" "slices"
"strings" "strings"
) )
@@ -41,6 +41,7 @@ func (d List) ToSafeStringList() []string {
} }
// String converts List to a comma-separated string. // String converts List to a comma-separated string.
// This is useful for displaying domain names in a user-friendly format.
func (d List) String() (string, error) { func (d List) String() (string, error) {
list, err := d.ToStringList() list, err := d.ToStringList()
if err != nil { if err != nil {
@@ -50,7 +51,8 @@ func (d List) String() (string, error) {
} }
// SafeString converts List to a comma-separated non-punycode string. // SafeString converts List to a comma-separated non-punycode string.
// If a domain cannot be converted, the original string is used. // This is useful for displaying domain names in a user-friendly format.
// If a domain cannot be converted, the original (punycode) string is used.
func (d List) SafeString() string { func (d List) SafeString() string {
str, err := d.String() str, err := d.String()
if err != nil { if err != nil {
@@ -64,28 +66,22 @@ func (d List) PunycodeString() string {
return strings.Join(d.ToPunycodeList(), ", ") return strings.Join(d.ToPunycodeList(), ", ")
} }
// Equal checks if two domain lists are equal without considering the order.
func (d List) Equal(domains List) bool { func (d List) Equal(domains List) bool {
if len(d) != len(domains) { if len(d) != len(domains) {
return false return false
} }
sort.Slice(d, func(i, j int) bool { d1 := slices.Clone(d)
return d[i] < d[j] d2 := slices.Clone(domains)
})
sort.Slice(domains, func(i, j int) bool { slices.Sort(d1)
return domains[i] < domains[j] slices.Sort(d2)
})
for i, domain := range d { return slices.Equal(d1, d2)
if domain != domains[i] {
return false
}
}
return true
} }
// FromStringList creates a DomainList from a slice of string. // FromStringList creates a List from a slice of strings.
func FromStringList(s []string) (List, error) { func FromStringList(s []string) (List, error) {
var dl List var dl List
for _, domain := range s { for _, domain := range s {

View File

@@ -78,7 +78,7 @@ type Manager interface {
DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error
ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error)
GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains domain.List, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error
ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)

View File

@@ -19,6 +19,9 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server/idp"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
nbAccount "github.com/netbirdio/netbird/management/server/account" nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
@@ -1688,7 +1691,7 @@ func TestAccount_Copy(t *testing.T) {
NameServerGroups: map[string]*nbdns.NameServerGroup{ NameServerGroups: map[string]*nbdns.NameServerGroup{
"nsGroup1": { "nsGroup1": {
ID: "nsGroup1", ID: "nsGroup1",
Domains: []string{}, Domains: domain.List{},
Groups: []string{}, Groups: []string{},
NameServers: []nbdns.NameServer{}, NameServers: []nbdns.NameServer{},
}, },

View File

@@ -258,7 +258,7 @@ func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup { func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
protoGroup := &proto.NameServerGroup{ protoGroup := &proto.NameServerGroup{
Primary: nsGroup.Primary, Primary: nsGroup.Primary,
Domains: nsGroup.Domains, Domains: nsGroup.Domains.ToPunycodeList(),
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled, SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)), NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
} }

View File

@@ -12,6 +12,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
@@ -367,7 +368,7 @@ func generateTestData(size int) nbdns.Config {
config.NameServerGroups[i] = &nbdns.NameServerGroup{ config.NameServerGroups[i] = &nbdns.NameServerGroup{
ID: fmt.Sprintf("group%d", i), ID: fmt.Sprintf("group%d", i),
Primary: i == 0, Primary: i == 0,
Domains: []string{fmt.Sprintf("domain%d.com", i)}, Domains: domain.List{domain.Domain(fmt.Sprintf("domain%d.com", i))},
SearchDomainsEnabled: true, SearchDomainsEnabled: true,
NameServers: []nbdns.NameServer{ NameServers: []nbdns.NameServer{
{ {
@@ -547,7 +548,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
Port: dns.DefaultDNSPort, Port: dns.DefaultDNSPort,
}}, }},
[]string{"groupB"}, []string{"groupB"},
true, []string{}, true, userID, false, true, domain.List{}, true, userID, false,
) )
assert.NoError(t, err) assert.NoError(t, err)
@@ -580,7 +581,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
Port: dns.DefaultDNSPort, Port: dns.DefaultDNSPort,
}}, }},
[]string{"groupA"}, []string{"groupA"},
true, []string{}, true, userID, false, true, domain.List{}, true, userID, false,
) )
assert.NoError(t, err) assert.NoError(t, err)

View File

@@ -9,6 +9,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
@@ -83,7 +84,13 @@ func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *htt
return return
} }
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userID, req.SearchDomainsEnabled) domains, err := domain.FromStringList(req.Domains)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains format"), w)
return
}
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, domains, req.Enabled, userID, req.SearchDomainsEnabled)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -123,12 +130,18 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
return return
} }
domains, err := domain.FromStringList(req.Domains)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains format"), w)
return
}
updatedNSGroup := &nbdns.NameServerGroup{ updatedNSGroup := &nbdns.NameServerGroup{
ID: nsGroupID, ID: nsGroupID,
Name: req.Name, Name: req.Name,
Description: req.Description, Description: req.Description,
Primary: req.Primary, Primary: req.Primary,
Domains: req.Domains, Domains: domains,
NameServers: nsList, NameServers: nsList,
Groups: req.Groups, Groups: req.Groups,
Enabled: req.Enabled, Enabled: req.Enabled,
@@ -227,7 +240,7 @@ func toNameserverGroupResponse(serverNSGroup *nbdns.NameServerGroup) *api.Namese
Name: serverNSGroup.Name, Name: serverNSGroup.Name,
Description: serverNSGroup.Description, Description: serverNSGroup.Description,
Primary: serverNSGroup.Primary, Primary: serverNSGroup.Primary,
Domains: serverNSGroup.Domains, Domains: serverNSGroup.Domains.ToSafeStringList(),
Groups: serverNSGroup.Groups, Groups: serverNSGroup.Groups,
Nameservers: nsList, Nameservers: nsList,
Enabled: serverNSGroup.Enabled, Enabled: serverNSGroup.Enabled,

View File

@@ -10,17 +10,15 @@ import (
"net/netip" "net/netip"
"testing" "testing"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server/status"
"github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status"
) )
const ( const (
@@ -47,7 +45,7 @@ var baseExistingNSGroup = &nbdns.NameServerGroup{
}, },
}, },
Groups: []string{"testing"}, Groups: []string{"testing"},
Domains: []string{"domain"}, Domains: domain.List{"domain"},
Enabled: true, Enabled: true,
} }
@@ -60,7 +58,7 @@ func initNameserversTestData() *nameserversHandler {
} }
return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID) return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
}, },
CreateNameServerGroupFunc: func(_ context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, _ string, searchDomains bool) (*nbdns.NameServerGroup, error) { CreateNameServerGroupFunc: func(_ context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains domain.List, enabled bool, _ string, searchDomains bool) (*nbdns.NameServerGroup, error) {
return &nbdns.NameServerGroup{ return &nbdns.NameServerGroup{
ID: existingNSGroupID, ID: existingNSGroupID,
Name: name, Name: name,

View File

@@ -77,7 +77,7 @@ type MockAccountManager struct {
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error) GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error)
GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*types.PersonalAccessToken, error) GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*types.PersonalAccessToken, error)
GetNameServerGroupFunc func(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) GetNameServerGroupFunc func(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroupFunc func(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) CreateNameServerGroupFunc func(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains domain.List, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
SaveNameServerGroupFunc func(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error SaveNameServerGroupFunc func(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
@@ -567,7 +567,7 @@ func (am *MockAccountManager) GetNameServerGroup(ctx context.Context, accountID,
} }
// CreateNameServerGroup mocks CreateNameServerGroup of the AccountManager interface // CreateNameServerGroup mocks CreateNameServerGroup of the AccountManager interface
func (am *MockAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) { func (am *MockAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains domain.List, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) {
if am.CreateNameServerGroupFunc != nil { if am.CreateNameServerGroupFunc != nil {
return am.CreateNameServerGroupFunc(ctx, accountID, name, description, nameServerList, groups, primary, domains, enabled, userID, searchDomainsEnabled) return am.CreateNameServerGroupFunc(ctx, accountID, name, description, nameServerList, groups, primary, domains, enabled, userID, searchDomainsEnabled)
} }

View File

@@ -10,6 +10,7 @@ import (
"github.com/rs/xid" "github.com/rs/xid"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
@@ -18,7 +19,7 @@ import (
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )
const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*[*.a-z]{1,}$` const domainPattern = `^(?i)[a-z0-9]+([\-]+[a-z0-9]+)*[*.a-z]{1,}$`
var invalidDomainName = errors.New("invalid domain name") var invalidDomainName = errors.New("invalid domain name")
@@ -36,7 +37,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
} }
// CreateNameServerGroup creates and saves a new nameserver group // CreateNameServerGroup creates and saves a new nameserver group
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) { func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains domain.List, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
@@ -252,7 +253,7 @@ func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction store
return anyGroupHasPeersOrResources(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups) return anyGroupHasPeersOrResources(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups)
} }
func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error { func validateDomainInput(primary bool, domains domain.List, searchDomainsEnabled bool) error {
if !primary && len(domains) == 0 { if !primary && len(domains) == 0 {
return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+ return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+
" it should be primary or have at least one domain") " it should be primary or have at least one domain")
@@ -268,7 +269,7 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo
} }
for _, domain := range domains { for _, domain := range domains {
if err := validateDomain(domain); err != nil { if err := validateDomain(domain.PunycodeString()); err != nil {
return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s %q", domain, err) return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s %q", domain, err)
} }
} }

View File

@@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -41,7 +42,7 @@ func TestCreateNameServerGroup(t *testing.T) {
groups []string groups []string
nameServers []nbdns.NameServer nameServers []nbdns.NameServer
primary bool primary bool
domains []string domains domain.List
searchDomains bool searchDomains bool
} }
@@ -102,7 +103,7 @@ func TestCreateNameServerGroup(t *testing.T) {
description: "super", description: "super",
groups: []string{group1ID}, groups: []string{group1ID},
primary: false, primary: false,
domains: []string{validDomain}, domains: domain.List{validDomain},
nameServers: []nbdns.NameServer{ nameServers: []nbdns.NameServer{
{ {
IP: netip.MustParseAddr("1.1.1.1"), IP: netip.MustParseAddr("1.1.1.1"),
@@ -123,7 +124,7 @@ func TestCreateNameServerGroup(t *testing.T) {
Name: "super", Name: "super",
Description: "super", Description: "super",
Primary: false, Primary: false,
Domains: []string{"example.com"}, Domains: domain.List{"example.com"},
Groups: []string{group1ID}, Groups: []string{group1ID},
NameServers: []nbdns.NameServer{ NameServers: []nbdns.NameServer{
{ {
@@ -360,7 +361,7 @@ func TestCreateNameServerGroup(t *testing.T) {
name: "super", name: "super",
description: "super", description: "super",
groups: []string{group1ID}, groups: []string{group1ID},
domains: []string{invalidDomain}, domains: domain.List{invalidDomain},
nameServers: []nbdns.NameServer{ nameServers: []nbdns.NameServer{
{ {
IP: netip.MustParseAddr("1.1.1.1"), IP: netip.MustParseAddr("1.1.1.1"),
@@ -447,8 +448,8 @@ func TestSaveNameServerGroup(t *testing.T) {
validGroups := []string{group2ID} validGroups := []string{group2ID}
invalidGroups := []string{"nonExisting"} invalidGroups := []string{"nonExisting"}
disabledPrimary := false disabledPrimary := false
validDomains := []string{validDomain} validDomains := domain.List{validDomain}
invalidDomains := []string{invalidDomain} invalidDomains := domain.List{invalidDomain}
validNameServerList := []nbdns.NameServer{ validNameServerList := []nbdns.NameServer{
{ {
@@ -491,7 +492,7 @@ func TestSaveNameServerGroup(t *testing.T) {
newID *string newID *string
newName *string newName *string
newPrimary *bool newPrimary *bool
newDomains []string newDomains domain.List
newNSList []nbdns.NameServer newNSList []nbdns.NameServer
newGroups []string newGroups []string
skipCopying bool skipCopying bool
@@ -908,6 +909,11 @@ func TestValidateDomain(t *testing.T) {
domain: "example.", domain: "example.",
errFunc: require.NoError, errFunc: require.NoError,
}, },
{
name: "Valid domain name with double hyphen",
domain: "xn--bcher-kva.com",
errFunc: require.NoError,
},
{ {
name: "Invalid wildcard domain name", name: "Invalid wildcard domain name",
domain: "*.example", domain: "*.example",
@@ -924,8 +930,8 @@ func TestValidateDomain(t *testing.T) {
errFunc: require.Error, errFunc: require.Error,
}, },
{ {
name: "Invalid domain name with double hyphen", name: "Invalid domain name with double dot",
domain: "test--example.com", domain: "example..com",
errFunc: require.Error, errFunc: require.Error,
}, },
{ {
@@ -1009,7 +1015,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
Port: nbdns.DefaultDNSPort, Port: nbdns.DefaultDNSPort,
}}, }},
[]string{"groupA"}, []string{"groupA"},
true, []string{}, true, userID, false, true, domain.List{}, true, userID, false,
) )
assert.NoError(t, err) assert.NoError(t, err)
@@ -1054,7 +1060,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
Port: nbdns.DefaultDNSPort, Port: nbdns.DefaultDNSPort,
}}, }},
[]string{"groupB"}, []string{"groupB"},
true, []string{}, true, userID, false, true, domain.List{}, true, userID, false,
) )
assert.NoError(t, err) assert.NoError(t, err)

View File

@@ -1108,7 +1108,7 @@ func TestToSyncResponse(t *testing.T) {
Port: nbdns.DefaultDNSPort, Port: nbdns.DefaultDNSPort,
}}, }},
Primary: true, Primary: true,
Domains: []string{"example.com"}, Domains: domain.List{"example.com"},
Enabled: true, Enabled: true,
SearchDomainsEnabled: true, SearchDomainsEnabled: true,
}, },
@@ -1121,7 +1121,7 @@ func TestToSyncResponse(t *testing.T) {
}}, }},
Groups: []string{"group1"}, Groups: []string{"group1"},
Primary: true, Primary: true,
Domains: []string{"example.com"}, Domains: domain.List{"example.com"},
Enabled: true, Enabled: true,
SearchDomainsEnabled: true, SearchDomainsEnabled: true,
}, },
@@ -1995,7 +1995,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
Port: nbdns.DefaultDNSPort, Port: nbdns.DefaultDNSPort,
}}, }},
[]string{"groupC"}, []string{"groupC"},
true, []string{}, true, userID, false, true, domain.List{}, true, userID, false,
) )
require.NoError(t, err) require.NoError(t, err)