mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 00:36:38 +00:00
Compare commits
9 Commits
trigger-pr
...
add-ns-pun
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ff5eddf70b | ||
|
|
273160c682 | ||
|
|
1d6c360aec | ||
|
|
f04e7c3f06 | ||
|
|
3d89cd43c2 | ||
|
|
0eeda712d0 | ||
|
|
3e3268db5f | ||
|
|
31f0879e71 | ||
|
|
f25b5bb987 |
@@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
|
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
|
||||||
@@ -43,11 +44,11 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) {
|
func (t *WGTunDevice) Create(routes []string, dns string, searchDomains domain.List) (WGConfigurer, error) {
|
||||||
log.Info("create tun interface")
|
log.Info("create tun interface")
|
||||||
|
|
||||||
routesString := routesToString(routes)
|
routesString := routesToString(routes)
|
||||||
searchDomainsToString := searchDomainsToString(searchDomains)
|
searchDomainsToString := searchDomainsToString(searchDomains.ToPunycodeList())
|
||||||
|
|
||||||
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
|
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -8,10 +8,11 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
type WGTunDevice interface {
|
type WGTunDevice interface {
|
||||||
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
|
Create(routes []string, dns string, searchDomains domain.List) (device.WGConfigurer, error)
|
||||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||||
UpdateAddr(address wgaddr.Address) error
|
UpdateAddr(address wgaddr.Address) error
|
||||||
WgAddress() wgaddr.Address
|
WgAddress() wgaddr.Address
|
||||||
|
|||||||
@@ -2,7 +2,11 @@
|
|||||||
|
|
||||||
package iface
|
package iface
|
||||||
|
|
||||||
import "fmt"
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
|
)
|
||||||
|
|
||||||
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
||||||
// Will reuse an existing one.
|
// Will reuse an existing one.
|
||||||
@@ -21,6 +25,6 @@ func (w *WGIface) Create() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateOnAndroid this function make sense on mobile only
|
// CreateOnAndroid this function make sense on mobile only
|
||||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
func (w *WGIface) CreateOnAndroid([]string, string, domain.List) error {
|
||||||
return fmt.Errorf("this function has not implemented on non mobile")
|
return fmt.Errorf("this function has not implemented on non mobile")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,11 +2,13 @@ package iface
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
|
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
|
||||||
// Will reuse an existing one.
|
// Will reuse an existing one.
|
||||||
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
|
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains domain.List) error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
||||||
@@ -36,6 +38,6 @@ func (w *WGIface) Create() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateOnAndroid this function make sense on mobile only
|
// CreateOnAndroid this function make sense on mobile only
|
||||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
func (w *WGIface) CreateOnAndroid([]string, string, domain.List) error {
|
||||||
return fmt.Errorf("this function has not implemented on this platform")
|
return fmt.Errorf("this function has not implemented on this platform")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -239,7 +239,7 @@ func searchDomains(config HostDNSConfig) []string {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
listOfDomains = append(listOfDomains, strings.TrimSuffix(dConf.Domain, "."))
|
listOfDomains = append(listOfDomains, strings.TrimSuffix(dConf.Domain.PunycodeString(), "."))
|
||||||
}
|
}
|
||||||
return listOfDomains
|
return listOfDomains
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ import (
|
|||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -24,8 +26,8 @@ type SubdomainMatcher interface {
|
|||||||
type HandlerEntry struct {
|
type HandlerEntry struct {
|
||||||
Handler dns.Handler
|
Handler dns.Handler
|
||||||
Priority int
|
Priority int
|
||||||
Pattern string
|
Pattern domain.Domain
|
||||||
OrigPattern string
|
OrigPattern domain.Domain
|
||||||
IsWildcard bool
|
IsWildcard bool
|
||||||
MatchSubdomains bool
|
MatchSubdomains bool
|
||||||
}
|
}
|
||||||
@@ -39,7 +41,7 @@ type HandlerChain struct {
|
|||||||
// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain
|
// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain
|
||||||
type ResponseWriterChain struct {
|
type ResponseWriterChain struct {
|
||||||
dns.ResponseWriter
|
dns.ResponseWriter
|
||||||
origPattern string
|
origPattern domain.Domain
|
||||||
shouldContinue bool
|
shouldContinue bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -59,18 +61,18 @@ func NewHandlerChain() *HandlerChain {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetOrigPattern returns the original pattern of the handler that wrote the response
|
// GetOrigPattern returns the original pattern of the handler that wrote the response
|
||||||
func (w *ResponseWriterChain) GetOrigPattern() string {
|
func (w *ResponseWriterChain) GetOrigPattern() domain.Domain {
|
||||||
return w.origPattern
|
return w.origPattern
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
|
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
|
||||||
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int) {
|
func (c *HandlerChain) AddHandler(pattern domain.Domain, handler dns.Handler, priority int) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
pattern = strings.ToLower(dns.Fqdn(pattern))
|
pattern = domain.Domain(strings.ToLower(dns.Fqdn(pattern.PunycodeString())))
|
||||||
origPattern := pattern
|
origPattern := pattern
|
||||||
isWildcard := strings.HasPrefix(pattern, "*.")
|
isWildcard := strings.HasPrefix(pattern.PunycodeString(), "*.")
|
||||||
if isWildcard {
|
if isWildcard {
|
||||||
pattern = pattern[2:]
|
pattern = pattern[2:]
|
||||||
}
|
}
|
||||||
@@ -110,8 +112,8 @@ func (c *HandlerChain) findHandlerPosition(newEntry HandlerEntry) int {
|
|||||||
|
|
||||||
// domain specificity next
|
// domain specificity next
|
||||||
if h.Priority == newEntry.Priority {
|
if h.Priority == newEntry.Priority {
|
||||||
newDots := strings.Count(newEntry.Pattern, ".")
|
newDots := strings.Count(newEntry.Pattern.PunycodeString(), ".")
|
||||||
existingDots := strings.Count(h.Pattern, ".")
|
existingDots := strings.Count(h.Pattern.PunycodeString(), ".")
|
||||||
if newDots > existingDots {
|
if newDots > existingDots {
|
||||||
return i
|
return i
|
||||||
}
|
}
|
||||||
@@ -123,20 +125,20 @@ func (c *HandlerChain) findHandlerPosition(newEntry HandlerEntry) int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RemoveHandler removes a handler for the given pattern and priority
|
// RemoveHandler removes a handler for the given pattern and priority
|
||||||
func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
|
func (c *HandlerChain) RemoveHandler(pattern domain.Domain, priority int) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
pattern = dns.Fqdn(pattern)
|
pattern = domain.Domain(dns.Fqdn(pattern.PunycodeString()))
|
||||||
|
|
||||||
c.removeEntry(pattern, priority)
|
c.removeEntry(pattern, priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HandlerChain) removeEntry(pattern string, priority int) {
|
func (c *HandlerChain) removeEntry(pattern domain.Domain, priority int) {
|
||||||
// Find and remove handlers matching both original pattern (case-insensitive) and priority
|
// Find and remove handlers matching both original pattern (case-insensitive) and priority
|
||||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||||
entry := c.handlers[i]
|
entry := c.handlers[i]
|
||||||
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
if strings.EqualFold(entry.OrigPattern.PunycodeString(), pattern.PunycodeString()) && entry.Priority == priority {
|
||||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -201,16 +203,16 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
|||||||
case entry.Pattern == ".":
|
case entry.Pattern == ".":
|
||||||
return true
|
return true
|
||||||
case entry.IsWildcard:
|
case entry.IsWildcard:
|
||||||
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
|
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern.PunycodeString()), ".")
|
||||||
return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
|
return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern.PunycodeString())
|
||||||
default:
|
default:
|
||||||
// For non-wildcard patterns:
|
// For non-wildcard patterns:
|
||||||
// If handler wants subdomain matching, allow suffix match
|
// If handler wants subdomain matching, allow suffix match
|
||||||
// Otherwise require exact match
|
// Otherwise require exact match
|
||||||
if entry.MatchSubdomains {
|
if entry.MatchSubdomains {
|
||||||
return strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
|
return strings.EqualFold(qname, entry.Pattern.PunycodeString()) || strings.HasSuffix(qname, "."+entry.Pattern.PunycodeString())
|
||||||
} else {
|
} else {
|
||||||
return strings.EqualFold(qname, entry.Pattern)
|
return strings.EqualFold(qname, entry.Pattern.PunycodeString())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
|
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
|
||||||
@@ -50,8 +51,8 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
|||||||
func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
handlerDomain string
|
handlerDomain domain.Domain
|
||||||
queryDomain string
|
queryDomain domain.Domain
|
||||||
isWildcard bool
|
isWildcard bool
|
||||||
matchSubdomains bool
|
matchSubdomains bool
|
||||||
shouldMatch bool
|
shouldMatch bool
|
||||||
@@ -141,7 +142,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
|||||||
chain.AddHandler(pattern, handler, nbdns.PriorityDefault)
|
chain.AddHandler(pattern, handler, nbdns.PriorityDefault)
|
||||||
|
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
r.SetQuestion(tt.queryDomain.PunycodeString(), dns.TypeA)
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
|
|
||||||
chain.ServeDNS(w, r)
|
chain.ServeDNS(w, r)
|
||||||
@@ -160,17 +161,17 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
|||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
handlers []struct {
|
handlers []struct {
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
}
|
}
|
||||||
queryDomain string
|
queryDomain domain.Domain
|
||||||
expectedCalls int
|
expectedCalls int
|
||||||
expectedHandler int // index of the handler that should be called
|
expectedHandler int // index of the handler that should be called
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "wildcard and exact same priority - exact should win",
|
name: "wildcard and exact same priority - exact should win",
|
||||||
handlers: []struct {
|
handlers: []struct {
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
}{
|
}{
|
||||||
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||||
@@ -183,7 +184,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "higher priority wildcard over lower priority exact",
|
name: "higher priority wildcard over lower priority exact",
|
||||||
handlers: []struct {
|
handlers: []struct {
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
}{
|
}{
|
||||||
{pattern: "example.com.", priority: nbdns.PriorityDefault},
|
{pattern: "example.com.", priority: nbdns.PriorityDefault},
|
||||||
@@ -196,7 +197,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "multiple wildcards different priorities",
|
name: "multiple wildcards different priorities",
|
||||||
handlers: []struct {
|
handlers: []struct {
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
}{
|
}{
|
||||||
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||||
@@ -210,7 +211,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "subdomain with mix of patterns",
|
name: "subdomain with mix of patterns",
|
||||||
handlers: []struct {
|
handlers: []struct {
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
}{
|
}{
|
||||||
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||||
@@ -224,7 +225,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "root zone with specific domain",
|
name: "root zone with specific domain",
|
||||||
handlers: []struct {
|
handlers: []struct {
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
}{
|
}{
|
||||||
{pattern: ".", priority: nbdns.PriorityDefault},
|
{pattern: ".", priority: nbdns.PriorityDefault},
|
||||||
@@ -258,7 +259,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
|||||||
|
|
||||||
// Create and execute request
|
// Create and execute request
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
r.SetQuestion(tt.queryDomain.PunycodeString(), dns.TypeA)
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
chain.ServeDNS(w, r)
|
chain.ServeDNS(w, r)
|
||||||
|
|
||||||
@@ -330,7 +331,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
ops []struct {
|
ops []struct {
|
||||||
action string // "add" or "remove"
|
action string // "add" or "remove"
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
}
|
}
|
||||||
query string
|
query string
|
||||||
@@ -340,7 +341,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
name: "remove high priority keeps lower priority handler",
|
name: "remove high priority keeps lower priority handler",
|
||||||
ops: []struct {
|
ops: []struct {
|
||||||
action string
|
action string
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
}{
|
}{
|
||||||
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
||||||
@@ -357,7 +358,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
name: "remove lower priority keeps high priority handler",
|
name: "remove lower priority keeps high priority handler",
|
||||||
ops: []struct {
|
ops: []struct {
|
||||||
action string
|
action string
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
}{
|
}{
|
||||||
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
||||||
@@ -374,7 +375,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
name: "remove all handlers in order",
|
name: "remove all handlers in order",
|
||||||
ops: []struct {
|
ops: []struct {
|
||||||
action string
|
action string
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
}{
|
}{
|
||||||
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
||||||
@@ -436,7 +437,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
||||||
chain := nbdns.NewHandlerChain()
|
chain := nbdns.NewHandlerChain()
|
||||||
|
|
||||||
testDomain := "example.com."
|
testDomain := domain.Domain("example.com.")
|
||||||
testQuery := "test.example.com."
|
testQuery := "test.example.com."
|
||||||
|
|
||||||
// Create handlers with MatchSubdomains enabled
|
// Create handlers with MatchSubdomains enabled
|
||||||
@@ -518,7 +519,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
scenario string
|
scenario string
|
||||||
addHandlers []struct {
|
addHandlers []struct {
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomains bool
|
subdomains bool
|
||||||
shouldMatch bool
|
shouldMatch bool
|
||||||
@@ -530,7 +531,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
|||||||
name: "case insensitive exact match",
|
name: "case insensitive exact match",
|
||||||
scenario: "handler registered lowercase, query uppercase",
|
scenario: "handler registered lowercase, query uppercase",
|
||||||
addHandlers: []struct {
|
addHandlers: []struct {
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomains bool
|
subdomains bool
|
||||||
shouldMatch bool
|
shouldMatch bool
|
||||||
@@ -544,7 +545,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
|||||||
name: "case insensitive wildcard match",
|
name: "case insensitive wildcard match",
|
||||||
scenario: "handler registered mixed case wildcard, query different case",
|
scenario: "handler registered mixed case wildcard, query different case",
|
||||||
addHandlers: []struct {
|
addHandlers: []struct {
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomains bool
|
subdomains bool
|
||||||
shouldMatch bool
|
shouldMatch bool
|
||||||
@@ -558,7 +559,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
|||||||
name: "multiple handlers different case same domain",
|
name: "multiple handlers different case same domain",
|
||||||
scenario: "second handler should replace first despite case difference",
|
scenario: "second handler should replace first despite case difference",
|
||||||
addHandlers: []struct {
|
addHandlers: []struct {
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomains bool
|
subdomains bool
|
||||||
shouldMatch bool
|
shouldMatch bool
|
||||||
@@ -573,7 +574,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
|||||||
name: "subdomain matching case insensitive",
|
name: "subdomain matching case insensitive",
|
||||||
scenario: "handler with MatchSubdomains true should match regardless of case",
|
scenario: "handler with MatchSubdomains true should match regardless of case",
|
||||||
addHandlers: []struct {
|
addHandlers: []struct {
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomains bool
|
subdomains bool
|
||||||
shouldMatch bool
|
shouldMatch bool
|
||||||
@@ -587,7 +588,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
|||||||
name: "root zone case insensitive",
|
name: "root zone case insensitive",
|
||||||
scenario: "root zone handler should match regardless of case",
|
scenario: "root zone handler should match regardless of case",
|
||||||
addHandlers: []struct {
|
addHandlers: []struct {
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomains bool
|
subdomains bool
|
||||||
shouldMatch bool
|
shouldMatch bool
|
||||||
@@ -601,7 +602,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
|||||||
name: "multiple handlers different priority",
|
name: "multiple handlers different priority",
|
||||||
scenario: "should call higher priority handler despite case differences",
|
scenario: "should call higher priority handler despite case differences",
|
||||||
addHandlers: []struct {
|
addHandlers: []struct {
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomains bool
|
subdomains bool
|
||||||
shouldMatch bool
|
shouldMatch bool
|
||||||
@@ -618,7 +619,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
chain := nbdns.NewHandlerChain()
|
chain := nbdns.NewHandlerChain()
|
||||||
handlerCalls := make(map[string]bool) // track which patterns were called
|
handlerCalls := make(map[domain.Domain]bool) // track which patterns were called
|
||||||
|
|
||||||
// Add handlers according to test case
|
// Add handlers according to test case
|
||||||
for _, h := range tt.addHandlers {
|
for _, h := range tt.addHandlers {
|
||||||
@@ -686,19 +687,19 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
scenario string
|
scenario string
|
||||||
ops []struct {
|
ops []struct {
|
||||||
action string
|
action string
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomain bool
|
subdomain bool
|
||||||
}
|
}
|
||||||
query string
|
query domain.Domain
|
||||||
expectedMatch string
|
expectedMatch domain.Domain
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "more specific domain matches first",
|
name: "more specific domain matches first",
|
||||||
scenario: "sub.example.com should match before example.com",
|
scenario: "sub.example.com should match before example.com",
|
||||||
ops: []struct {
|
ops: []struct {
|
||||||
action string
|
action string
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomain bool
|
subdomain bool
|
||||||
}{
|
}{
|
||||||
@@ -713,7 +714,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
scenario: "sub.example.com should match before example.com",
|
scenario: "sub.example.com should match before example.com",
|
||||||
ops: []struct {
|
ops: []struct {
|
||||||
action string
|
action string
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomain bool
|
subdomain bool
|
||||||
}{
|
}{
|
||||||
@@ -728,7 +729,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
scenario: "after removing most specific, should fall back to less specific",
|
scenario: "after removing most specific, should fall back to less specific",
|
||||||
ops: []struct {
|
ops: []struct {
|
||||||
action string
|
action string
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomain bool
|
subdomain bool
|
||||||
}{
|
}{
|
||||||
@@ -745,7 +746,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
scenario: "less specific domain with higher priority should match first",
|
scenario: "less specific domain with higher priority should match first",
|
||||||
ops: []struct {
|
ops: []struct {
|
||||||
action string
|
action string
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomain bool
|
subdomain bool
|
||||||
}{
|
}{
|
||||||
@@ -760,7 +761,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
scenario: "with equal priority, more specific domain should match",
|
scenario: "with equal priority, more specific domain should match",
|
||||||
ops: []struct {
|
ops: []struct {
|
||||||
action string
|
action string
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomain bool
|
subdomain bool
|
||||||
}{
|
}{
|
||||||
@@ -776,7 +777,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
scenario: "specific domain should match before wildcard at same priority",
|
scenario: "specific domain should match before wildcard at same priority",
|
||||||
ops: []struct {
|
ops: []struct {
|
||||||
action string
|
action string
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomain bool
|
subdomain bool
|
||||||
}{
|
}{
|
||||||
@@ -791,7 +792,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
chain := nbdns.NewHandlerChain()
|
chain := nbdns.NewHandlerChain()
|
||||||
handlers := make(map[string]*nbdns.MockSubdomainHandler)
|
handlers := make(map[domain.Domain]*nbdns.MockSubdomainHandler)
|
||||||
|
|
||||||
for _, op := range tt.ops {
|
for _, op := range tt.ops {
|
||||||
if op.action == "add" {
|
if op.action == "add" {
|
||||||
@@ -804,7 +805,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tt.query, dns.TypeA)
|
r.SetQuestion(tt.query.PunycodeString(), dns.TypeA)
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
|
|
||||||
// Setup handler expectations
|
// Setup handler expectations
|
||||||
@@ -836,9 +837,9 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
addPattern string
|
addPattern domain.Domain
|
||||||
removePattern string
|
removePattern domain.Domain
|
||||||
queryPattern string
|
queryPattern domain.Domain
|
||||||
shouldBeRemoved bool
|
shouldBeRemoved bool
|
||||||
description string
|
description string
|
||||||
}{
|
}{
|
||||||
@@ -954,7 +955,7 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
|||||||
|
|
||||||
handler := &nbdns.MockHandler{}
|
handler := &nbdns.MockHandler{}
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tt.queryPattern, dns.TypeA)
|
r.SetQuestion(tt.queryPattern.PunycodeString(), dns.TypeA)
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
|
|
||||||
// First verify no handler is called before adding any
|
// First verify no handler is called before adding any
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
||||||
@@ -40,7 +41,7 @@ type HostDNSConfig struct {
|
|||||||
|
|
||||||
type DomainConfig struct {
|
type DomainConfig struct {
|
||||||
Disabled bool `json:"disabled"`
|
Disabled bool `json:"disabled"`
|
||||||
Domain string `json:"domain"`
|
Domain domain.Domain `json:"domain"`
|
||||||
MatchOnly bool `json:"matchOnly"`
|
MatchOnly bool `json:"matchOnly"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,18 +104,20 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
|
|||||||
config.RouteAll = true
|
config.RouteAll = true
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, domain := range nsConfig.Domains {
|
for _, d := range nsConfig.Domains {
|
||||||
|
d := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||||
config.Domains = append(config.Domains, DomainConfig{
|
config.Domains = append(config.Domains, DomainConfig{
|
||||||
Domain: strings.ToLower(dns.Fqdn(domain)),
|
Domain: domain.Domain(d),
|
||||||
MatchOnly: !nsConfig.SearchDomainsEnabled,
|
MatchOnly: !nsConfig.SearchDomainsEnabled,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, customZone := range dnsConfig.CustomZones {
|
for _, customZone := range dnsConfig.CustomZones {
|
||||||
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
|
d := strings.ToLower(dns.Fqdn(customZone.Domain))
|
||||||
|
matchOnly := strings.HasSuffix(d, ipv4ReverseZone) || strings.HasSuffix(d, ipv6ReverseZone)
|
||||||
config.Domains = append(config.Domains, DomainConfig{
|
config.Domains = append(config.Domains, DomainConfig{
|
||||||
Domain: strings.ToLower(dns.Fqdn(customZone.Domain)),
|
Domain: domain.Domain(d),
|
||||||
MatchOnly: matchOnly,
|
MatchOnly: matchOnly,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -79,10 +79,10 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if dConf.MatchOnly {
|
if dConf.MatchOnly {
|
||||||
matchDomains = append(matchDomains, strings.TrimSuffix(dConf.Domain, "."))
|
matchDomains = append(matchDomains, strings.TrimSuffix(dConf.Domain.PunycodeString(), "."))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, "."))
|
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain.PunycodeString(), "."))
|
||||||
}
|
}
|
||||||
|
|
||||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||||
|
|||||||
@@ -186,9 +186,9 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if !dConf.MatchOnly {
|
if !dConf.MatchOnly {
|
||||||
searchDomains = append(searchDomains, strings.TrimSuffix(dConf.Domain, "."))
|
searchDomains = append(searchDomains, strings.TrimSuffix(dConf.Domain.PunycodeString(), "."))
|
||||||
}
|
}
|
||||||
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, "."))
|
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain.PunycodeString(), "."))
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(matchDomains) != 0 {
|
if len(matchDomains) != 0 {
|
||||||
|
|||||||
@@ -62,8 +62,8 @@ func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
|||||||
return fmt.Errorf("method UpdateDNSServer is not implemented")
|
return fmt.Errorf("method UpdateDNSServer is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockServer) SearchDomains() []string {
|
func (m *MockServer) SearchDomains() domain.List {
|
||||||
return make([]string, 0)
|
return make(domain.List, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface
|
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface
|
||||||
|
|||||||
@@ -125,10 +125,10 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if dConf.MatchOnly {
|
if dConf.MatchOnly {
|
||||||
matchDomains = append(matchDomains, "~."+dConf.Domain)
|
matchDomains = append(matchDomains, "~."+dConf.Domain.PunycodeString())
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
searchDomains = append(searchDomains, dConf.Domain)
|
searchDomains = append(searchDomains, dConf.Domain.PunycodeString())
|
||||||
}
|
}
|
||||||
|
|
||||||
newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic
|
newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic
|
||||||
|
|||||||
@@ -1,21 +1,19 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"reflect"
|
|
||||||
"sort"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
type notifier struct {
|
type notifier struct {
|
||||||
listener listener.NetworkChangeListener
|
listener listener.NetworkChangeListener
|
||||||
listenerMux sync.Mutex
|
listenerMux sync.Mutex
|
||||||
searchDomains []string
|
searchDomains domain.List
|
||||||
}
|
}
|
||||||
|
|
||||||
func newNotifier(initialSearchDomains []string) *notifier {
|
func newNotifier(initialSearchDomains domain.List) *notifier {
|
||||||
sort.Strings(initialSearchDomains)
|
|
||||||
return ¬ifier{
|
return ¬ifier{
|
||||||
searchDomains: initialSearchDomains,
|
searchDomains: initialSearchDomains,
|
||||||
}
|
}
|
||||||
@@ -27,16 +25,8 @@ func (n *notifier) setListener(listener listener.NetworkChangeListener) {
|
|||||||
n.listener = listener
|
n.listener = listener
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *notifier) onNewSearchDomains(searchDomains []string) {
|
func (n *notifier) onNewSearchDomains(searchDomains domain.List) {
|
||||||
sort.Strings(searchDomains)
|
if searchDomains.Equal(n.searchDomains) {
|
||||||
|
|
||||||
if len(n.searchDomains) != len(searchDomains) {
|
|
||||||
n.searchDomains = searchDomains
|
|
||||||
n.notify()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if reflect.DeepEqual(n.searchDomains, searchDomains) {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -44,12 +44,12 @@ type Server interface {
|
|||||||
DnsIP() string
|
DnsIP() string
|
||||||
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
||||||
OnUpdatedHostDNSServer(strings []string)
|
OnUpdatedHostDNSServer(strings []string)
|
||||||
SearchDomains() []string
|
SearchDomains() domain.List
|
||||||
ProbeAvailability()
|
ProbeAvailability()
|
||||||
}
|
}
|
||||||
|
|
||||||
type nsGroupsByDomain struct {
|
type nsGroupsByDomain struct {
|
||||||
domain string
|
domain domain.Domain
|
||||||
groups []*nbdns.NameServerGroup
|
groups []*nbdns.NameServerGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -90,7 +90,7 @@ type handlerWithStop interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type handlerWrapper struct {
|
type handlerWrapper struct {
|
||||||
domain string
|
domain domain.Domain
|
||||||
handler handlerWithStop
|
handler handlerWithStop
|
||||||
priority int
|
priority int
|
||||||
}
|
}
|
||||||
@@ -197,7 +197,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
|
|||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
s.registerHandler(domains.ToPunycodeList(), handler, priority)
|
s.registerHandler(domains, handler, priority)
|
||||||
|
|
||||||
// TODO: This will take over zones for non-wildcard domains, for which we might not have a handler in the chain
|
// TODO: This will take over zones for non-wildcard domains, for which we might not have a handler in the chain
|
||||||
for _, domain := range domains {
|
for _, domain := range domains {
|
||||||
@@ -207,7 +207,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
|
|||||||
s.applyHostConfig()
|
s.applyHostConfig()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
func (s *DefaultServer) registerHandler(domains domain.List, handler dns.Handler, priority int) {
|
||||||
log.Debugf("registering handler %s with priority %d", handler, priority)
|
log.Debugf("registering handler %s with priority %d", handler, priority)
|
||||||
|
|
||||||
for _, domain := range domains {
|
for _, domain := range domains {
|
||||||
@@ -224,7 +224,7 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
|
|||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
s.deregisterHandler(domains.ToPunycodeList(), priority)
|
s.deregisterHandler(domains, priority)
|
||||||
for _, domain := range domains {
|
for _, domain := range domains {
|
||||||
zone := toZone(domain)
|
zone := toZone(domain)
|
||||||
s.extraDomains[zone]--
|
s.extraDomains[zone]--
|
||||||
@@ -235,7 +235,7 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
|
|||||||
s.applyHostConfig()
|
s.applyHostConfig()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
func (s *DefaultServer) deregisterHandler(domains domain.List, priority int) {
|
||||||
log.Debugf("deregistering handler %v with priority %d", domains, priority)
|
log.Debugf("deregistering handler %v with priority %d", domains, priority)
|
||||||
|
|
||||||
for _, domain := range domains {
|
for _, domain := range domains {
|
||||||
@@ -378,8 +378,8 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) SearchDomains() []string {
|
func (s *DefaultServer) SearchDomains() domain.List {
|
||||||
var searchDomains []string
|
var searchDomains domain.List
|
||||||
|
|
||||||
for _, dConf := range s.currentConfig.Domains {
|
for _, dConf := range s.currentConfig.Domains {
|
||||||
if dConf.Disabled {
|
if dConf.Disabled {
|
||||||
@@ -472,18 +472,16 @@ func (s *DefaultServer) applyHostConfig() {
|
|||||||
|
|
||||||
config := s.currentConfig
|
config := s.currentConfig
|
||||||
|
|
||||||
existingDomains := make(map[string]struct{})
|
existingDomains := make(map[domain.Domain]struct{})
|
||||||
for _, d := range config.Domains {
|
for _, d := range config.Domains {
|
||||||
existingDomains[d.Domain] = struct{}{}
|
existingDomains[d.Domain] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// add extra domains only if they're not already in the config
|
// add extra domains only if they're not already in the config
|
||||||
for domain := range s.extraDomains {
|
for d := range s.extraDomains {
|
||||||
domainStr := domain.PunycodeString()
|
if _, exists := existingDomains[d]; !exists {
|
||||||
|
|
||||||
if _, exists := existingDomains[domainStr]; !exists {
|
|
||||||
config.Domains = append(config.Domains, DomainConfig{
|
config.Domains = append(config.Domains, DomainConfig{
|
||||||
Domain: domainStr,
|
Domain: d,
|
||||||
MatchOnly: true,
|
MatchOnly: true,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -525,7 +523,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
|
|||||||
}
|
}
|
||||||
|
|
||||||
muxUpdates = append(muxUpdates, handlerWrapper{
|
muxUpdates = append(muxUpdates, handlerWrapper{
|
||||||
domain: customZone.Domain,
|
domain: domain.Domain(customZone.Domain),
|
||||||
handler: s.localResolver,
|
handler: s.localResolver,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
})
|
})
|
||||||
@@ -647,7 +645,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
|
|||||||
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
||||||
// this will introduce a short period of time when the server is not able to handle DNS requests
|
// this will introduce a short period of time when the server is not able to handle DNS requests
|
||||||
for _, existing := range s.dnsMuxMap {
|
for _, existing := range s.dnsMuxMap {
|
||||||
s.deregisterHandler([]string{existing.domain}, existing.priority)
|
s.deregisterHandler(domain.List{existing.domain}, existing.priority)
|
||||||
existing.handler.Stop()
|
existing.handler.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -658,7 +656,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
|||||||
if update.domain == nbdns.RootZone {
|
if update.domain == nbdns.RootZone {
|
||||||
containsRootUpdate = true
|
containsRootUpdate = true
|
||||||
}
|
}
|
||||||
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
s.registerHandler(domain.List{update.domain}, update.handler, update.priority)
|
||||||
muxUpdateMap[update.handler.ID()] = update
|
muxUpdateMap[update.handler.ID()] = update
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -687,7 +685,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
handler dns.Handler,
|
handler dns.Handler,
|
||||||
priority int,
|
priority int,
|
||||||
) (deactivate func(error), reactivate func()) {
|
) (deactivate func(error), reactivate func()) {
|
||||||
var removeIndex map[string]int
|
var removeIndex map[domain.Domain]int
|
||||||
deactivate = func(err error) {
|
deactivate = func(err error) {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
@@ -695,20 +693,20 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||||
l.Info("Temporarily deactivating nameservers group due to timeout")
|
l.Info("Temporarily deactivating nameservers group due to timeout")
|
||||||
|
|
||||||
removeIndex = make(map[string]int)
|
removeIndex = make(map[domain.Domain]int)
|
||||||
for _, domain := range nsGroup.Domains {
|
for _, domain := range nsGroup.Domains {
|
||||||
removeIndex[domain] = -1
|
removeIndex[domain] = -1
|
||||||
}
|
}
|
||||||
if nsGroup.Primary {
|
if nsGroup.Primary {
|
||||||
removeIndex[nbdns.RootZone] = -1
|
removeIndex[nbdns.RootZone] = -1
|
||||||
s.currentConfig.RouteAll = false
|
s.currentConfig.RouteAll = false
|
||||||
s.deregisterHandler([]string{nbdns.RootZone}, priority)
|
s.deregisterHandler(domain.List{nbdns.RootZone}, priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, item := range s.currentConfig.Domains {
|
for i, item := range s.currentConfig.Domains {
|
||||||
if _, found := removeIndex[item.Domain]; found {
|
if _, found := removeIndex[item.Domain]; found {
|
||||||
s.currentConfig.Domains[i].Disabled = true
|
s.currentConfig.Domains[i].Disabled = true
|
||||||
s.deregisterHandler([]string{item.Domain}, priority)
|
s.deregisterHandler(domain.List{item.Domain}, priority)
|
||||||
removeIndex[item.Domain] = i
|
removeIndex[item.Domain] = i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -732,12 +730,12 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
for domain, i := range removeIndex {
|
for d, i := range removeIndex {
|
||||||
if i == -1 || i >= len(s.currentConfig.Domains) || s.currentConfig.Domains[i].Domain != domain {
|
if i == -1 || i >= len(s.currentConfig.Domains) || s.currentConfig.Domains[i].Domain != d{
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.currentConfig.Domains[i].Disabled = false
|
s.currentConfig.Domains[i].Disabled = false
|
||||||
s.registerHandler([]string{domain}, handler, priority)
|
s.registerHandler(domain.List{d}, handler, priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||||
@@ -745,7 +743,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
|
|
||||||
if nsGroup.Primary {
|
if nsGroup.Primary {
|
||||||
s.currentConfig.RouteAll = true
|
s.currentConfig.RouteAll = true
|
||||||
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
|
s.registerHandler(domain.List{nbdns.RootZone}, handler, priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.applyHostConfig()
|
s.applyHostConfig()
|
||||||
@@ -777,7 +775,7 @@ func (s *DefaultServer) addHostRootZone() {
|
|||||||
handler.deactivate = func(error) {}
|
handler.deactivate = func(error) {}
|
||||||
handler.reactivate = func() {}
|
handler.reactivate = func() {}
|
||||||
|
|
||||||
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
|
s.registerHandler(domain.List{nbdns.RootZone}, handler, PriorityDefault)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
|
func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
|
||||||
@@ -792,7 +790,7 @@ func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
|
|||||||
state := peer.NSGroupState{
|
state := peer.NSGroupState{
|
||||||
ID: generateGroupKey(group),
|
ID: generateGroupKey(group),
|
||||||
Servers: servers,
|
Servers: servers,
|
||||||
Domains: group.Domains,
|
Domains: group.Domains.ToPunycodeList(),
|
||||||
// The probe will determine the state, default enabled
|
// The probe will determine the state, default enabled
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Error: nil,
|
Error: nil,
|
||||||
@@ -825,7 +823,7 @@ func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
|
|||||||
|
|
||||||
// groupNSGroupsByDomain groups nameserver groups by their match domains
|
// groupNSGroupsByDomain groups nameserver groups by their match domains
|
||||||
func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain {
|
func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain {
|
||||||
domainMap := make(map[string][]*nbdns.NameServerGroup)
|
domainMap := make(map[domain.Domain][]*nbdns.NameServerGroup)
|
||||||
|
|
||||||
for _, group := range nsGroups {
|
for _, group := range nsGroups {
|
||||||
if group.Primary {
|
if group.Primary {
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -96,7 +95,7 @@ func init() {
|
|||||||
formatter.SetTextFormatter(log.StandardLogger())
|
formatter.SetTextFormatter(log.StandardLogger())
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
|
func generateDummyHandler(domain domain.Domain, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||||
var srvs []string
|
var srvs []string
|
||||||
for _, srv := range servers {
|
for _, srv := range servers {
|
||||||
srvs = append(srvs, getNSHostPort(srv))
|
srvs = append(srvs, getNSHostPort(srv))
|
||||||
@@ -151,7 +150,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
{
|
{
|
||||||
Domains: []string{"netbird.io"},
|
Domains: domain.List{"netbird.io"},
|
||||||
NameServers: nameServers,
|
NameServers: nameServers,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -183,7 +182,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
name: "New Config Should Succeed",
|
name: "New Config Should Succeed",
|
||||||
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
||||||
initUpstreamMap: registeredHandlerMap{
|
initUpstreamMap: registeredHandlerMap{
|
||||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
generateDummyHandler(domain.Domain(zoneRecords[0].Name), nameServers).ID(): handlerWrapper{
|
||||||
domain: "netbird.cloud",
|
domain: "netbird.cloud",
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
@@ -201,7 +200,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
{
|
{
|
||||||
Domains: []string{"netbird.io"},
|
Domains: domain.List{"netbird.io"},
|
||||||
NameServers: nameServers,
|
NameServers: nameServers,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -302,8 +301,8 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
name: "Empty Config Should Succeed and Clean Maps",
|
name: "Empty Config Should Succeed and Clean Maps",
|
||||||
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
||||||
initUpstreamMap: registeredHandlerMap{
|
initUpstreamMap: registeredHandlerMap{
|
||||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
generateDummyHandler(domain.Domain(zoneRecords[0].Name), nameServers).ID(): handlerWrapper{
|
||||||
domain: zoneRecords[0].Name,
|
domain: domain.Domain(zoneRecords[0].Name),
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
},
|
},
|
||||||
@@ -318,8 +317,8 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
name: "Disabled Service Should clean map",
|
name: "Disabled Service Should clean map",
|
||||||
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
||||||
initUpstreamMap: registeredHandlerMap{
|
initUpstreamMap: registeredHandlerMap{
|
||||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
generateDummyHandler(domain.Domain(zoneRecords[0].Name), nameServers).ID(): handlerWrapper{
|
||||||
domain: zoneRecords[0].Name,
|
domain: domain.Domain(zoneRecords[0].Name),
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
},
|
},
|
||||||
@@ -493,7 +492,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
|
|
||||||
dnsServer.dnsMuxMap = registeredHandlerMap{
|
dnsServer.dnsMuxMap = registeredHandlerMap{
|
||||||
"id1": handlerWrapper{
|
"id1": handlerWrapper{
|
||||||
domain: zoneRecords[0].Name,
|
domain: domain.Domain(zoneRecords[0].Name),
|
||||||
handler: &local.Resolver{},
|
handler: &local.Resolver{},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
},
|
},
|
||||||
@@ -525,7 +524,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
},
|
},
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
{
|
{
|
||||||
Domains: []string{"netbird.io"},
|
Domains: domain.List{"netbird.io"},
|
||||||
NameServers: nameServers,
|
NameServers: nameServers,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -591,7 +590,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer.registerHandler([]string{"netbird.cloud"}, dnsServer.localResolver, 1)
|
dnsServer.registerHandler(domain.List{"netbird.cloud"}, dnsServer.localResolver, 1)
|
||||||
|
|
||||||
resolver := &net.Resolver{
|
resolver := &net.Resolver{
|
||||||
PreferGo: true,
|
PreferGo: true,
|
||||||
@@ -651,48 +650,48 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
|||||||
|
|
||||||
var domainsUpdate string
|
var domainsUpdate string
|
||||||
hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error {
|
hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error {
|
||||||
domains := []string{}
|
domains := domain.List{}
|
||||||
for _, item := range config.Domains {
|
for _, item := range config.Domains {
|
||||||
if item.Disabled {
|
if item.Disabled {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
domains = append(domains, item.Domain)
|
domains = append(domains, item.Domain)
|
||||||
}
|
}
|
||||||
domainsUpdate = strings.Join(domains, ",")
|
domainsUpdate = domains.PunycodeString()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{
|
deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{
|
||||||
Domains: []string{"domain1"},
|
Domains: domain.List{"domain1"},
|
||||||
NameServers: []nbdns.NameServer{
|
NameServers: []nbdns.NameServer{
|
||||||
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
|
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||||
},
|
},
|
||||||
}, nil, 0)
|
}, nil, 0)
|
||||||
|
|
||||||
deactivate(nil)
|
deactivate(nil)
|
||||||
expected := "domain0,domain2"
|
expected := "domain0, domain2"
|
||||||
domains := []string{}
|
domains := domain.List{}
|
||||||
for _, item := range server.currentConfig.Domains {
|
for _, item := range server.currentConfig.Domains {
|
||||||
if item.Disabled {
|
if item.Disabled {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
domains = append(domains, item.Domain)
|
domains = append(domains, item.Domain)
|
||||||
}
|
}
|
||||||
got := strings.Join(domains, ",")
|
got := domains.PunycodeString()
|
||||||
if expected != got {
|
if expected != got {
|
||||||
t.Errorf("expected domains list: %q, got %q", expected, got)
|
t.Errorf("expected domains list: %q, got %q", expected, got)
|
||||||
}
|
}
|
||||||
|
|
||||||
reactivate()
|
reactivate()
|
||||||
expected = "domain0,domain1,domain2"
|
expected = "domain0, domain1, domain2"
|
||||||
domains = []string{}
|
domains = domain.List{}
|
||||||
for _, item := range server.currentConfig.Domains {
|
for _, item := range server.currentConfig.Domains {
|
||||||
if item.Disabled {
|
if item.Disabled {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
domains = append(domains, item.Domain)
|
domains = append(domains, item.Domain)
|
||||||
}
|
}
|
||||||
got = strings.Join(domains, ",")
|
got = domains.PunycodeString()
|
||||||
if expected != got {
|
if expected != got {
|
||||||
t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate)
|
t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate)
|
||||||
}
|
}
|
||||||
@@ -860,7 +859,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
|||||||
Port: 53,
|
Port: 53,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Domains: []string{"google.com"},
|
Domains: domain.List{"google.com"},
|
||||||
Primary: false,
|
Primary: false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -1115,7 +1114,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
initialHandlers registeredHandlerMap
|
initialHandlers registeredHandlerMap
|
||||||
updates []handlerWrapper
|
updates []handlerWrapper
|
||||||
expectedHandlers map[string]string // map[HandlerID]domain
|
expectedHandlers map[string]domain.Domain // map[HandlerID]domain
|
||||||
description string
|
description string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
@@ -1131,7 +1130,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
priority: PriorityMatchDomain - 1,
|
priority: PriorityMatchDomain - 1,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedHandlers: map[string]string{
|
expectedHandlers: map[string]domain.Domain{
|
||||||
"upstream-group2": "example.com",
|
"upstream-group2": "example.com",
|
||||||
},
|
},
|
||||||
description: "When group1 is not included in the update, it should be removed while group2 remains",
|
description: "When group1 is not included in the update, it should be removed while group2 remains",
|
||||||
@@ -1149,7 +1148,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedHandlers: map[string]string{
|
expectedHandlers: map[string]domain.Domain{
|
||||||
"upstream-group1": "example.com",
|
"upstream-group1": "example.com",
|
||||||
},
|
},
|
||||||
description: "When group2 is not included in the update, it should be removed while group1 remains",
|
description: "When group2 is not included in the update, it should be removed while group1 remains",
|
||||||
@@ -1182,7 +1181,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
priority: PriorityMatchDomain - 1,
|
priority: PriorityMatchDomain - 1,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedHandlers: map[string]string{
|
expectedHandlers: map[string]domain.Domain{
|
||||||
"upstream-group1": "example.com",
|
"upstream-group1": "example.com",
|
||||||
"upstream-group2": "example.com",
|
"upstream-group2": "example.com",
|
||||||
"upstream-group3": "example.com",
|
"upstream-group3": "example.com",
|
||||||
@@ -1217,7 +1216,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
priority: PriorityMatchDomain - 2,
|
priority: PriorityMatchDomain - 2,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedHandlers: map[string]string{
|
expectedHandlers: map[string]domain.Domain{
|
||||||
"upstream-group1": "example.com",
|
"upstream-group1": "example.com",
|
||||||
"upstream-group2": "example.com",
|
"upstream-group2": "example.com",
|
||||||
"upstream-group3": "example.com",
|
"upstream-group3": "example.com",
|
||||||
@@ -1237,7 +1236,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
priority: PriorityDefault - 1,
|
priority: PriorityDefault - 1,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedHandlers: map[string]string{
|
expectedHandlers: map[string]domain.Domain{
|
||||||
"upstream-root2": ".",
|
"upstream-root2": ".",
|
||||||
},
|
},
|
||||||
description: "When root1 is not included in the update, it should be removed while root2 remains",
|
description: "When root1 is not included in the update, it should be removed while root2 remains",
|
||||||
@@ -1254,7 +1253,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
priority: PriorityDefault,
|
priority: PriorityDefault,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedHandlers: map[string]string{
|
expectedHandlers: map[string]domain.Domain{
|
||||||
"upstream-root1": ".",
|
"upstream-root1": ".",
|
||||||
},
|
},
|
||||||
description: "When root2 is not included in the update, it should be removed while root1 remains",
|
description: "When root2 is not included in the update, it should be removed while root1 remains",
|
||||||
@@ -1285,7 +1284,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
priority: PriorityDefault - 1,
|
priority: PriorityDefault - 1,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedHandlers: map[string]string{
|
expectedHandlers: map[string]domain.Domain{
|
||||||
"upstream-root1": ".",
|
"upstream-root1": ".",
|
||||||
"upstream-root2": ".",
|
"upstream-root2": ".",
|
||||||
"upstream-root3": ".",
|
"upstream-root3": ".",
|
||||||
@@ -1318,7 +1317,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
priority: PriorityDefault - 2,
|
priority: PriorityDefault - 2,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedHandlers: map[string]string{
|
expectedHandlers: map[string]domain.Domain{
|
||||||
"upstream-root1": ".",
|
"upstream-root1": ".",
|
||||||
"upstream-root2": ".",
|
"upstream-root2": ".",
|
||||||
"upstream-root3": ".",
|
"upstream-root3": ".",
|
||||||
@@ -1345,7 +1344,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedHandlers: map[string]string{
|
expectedHandlers: map[string]domain.Domain{
|
||||||
"upstream-group1": "example.com",
|
"upstream-group1": "example.com",
|
||||||
"upstream-other": "other.com",
|
"upstream-other": "other.com",
|
||||||
},
|
},
|
||||||
@@ -1384,7 +1383,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedHandlers: map[string]string{
|
expectedHandlers: map[string]domain.Domain{
|
||||||
"upstream-group1": "example.com",
|
"upstream-group1": "example.com",
|
||||||
"upstream-group2": "example.com",
|
"upstream-group2": "example.com",
|
||||||
"upstream-other": "other.com",
|
"upstream-other": "other.com",
|
||||||
@@ -1440,7 +1439,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
for _, muxEntry := range server.dnsMuxMap {
|
for _, muxEntry := range server.dnsMuxMap {
|
||||||
if chainEntry.Handler == muxEntry.handler &&
|
if chainEntry.Handler == muxEntry.handler &&
|
||||||
chainEntry.Priority == muxEntry.priority &&
|
chainEntry.Priority == muxEntry.priority &&
|
||||||
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) {
|
chainEntry.Pattern.PunycodeString() == dns.Fqdn(muxEntry.domain.PunycodeString()) {
|
||||||
foundInMux = true
|
foundInMux = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -1459,8 +1458,8 @@ func TestExtraDomains(t *testing.T) {
|
|||||||
registerDomains []domain.List
|
registerDomains []domain.List
|
||||||
deregisterDomains []domain.List
|
deregisterDomains []domain.List
|
||||||
finalConfig nbdns.Config
|
finalConfig nbdns.Config
|
||||||
expectedDomains []string
|
expectedDomains domain.List
|
||||||
expectedMatchOnly []string
|
expectedMatchOnly domain.List
|
||||||
applyHostConfigCall int
|
applyHostConfigCall int
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
@@ -1474,12 +1473,12 @@ func TestExtraDomains(t *testing.T) {
|
|||||||
{Domain: "config.example.com"},
|
{Domain: "config.example.com"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedDomains: []string{
|
expectedDomains: domain.List{
|
||||||
"config.example.com.",
|
"config.example.com.",
|
||||||
"extra1.example.com.",
|
"extra1.example.com.",
|
||||||
"extra2.example.com.",
|
"extra2.example.com.",
|
||||||
},
|
},
|
||||||
expectedMatchOnly: []string{
|
expectedMatchOnly: domain.List{
|
||||||
"extra1.example.com.",
|
"extra1.example.com.",
|
||||||
"extra2.example.com.",
|
"extra2.example.com.",
|
||||||
},
|
},
|
||||||
@@ -1496,12 +1495,12 @@ func TestExtraDomains(t *testing.T) {
|
|||||||
registerDomains: []domain.List{
|
registerDomains: []domain.List{
|
||||||
{"extra1.example.com", "extra2.example.com"},
|
{"extra1.example.com", "extra2.example.com"},
|
||||||
},
|
},
|
||||||
expectedDomains: []string{
|
expectedDomains: domain.List{
|
||||||
"config.example.com.",
|
"config.example.com.",
|
||||||
"extra1.example.com.",
|
"extra1.example.com.",
|
||||||
"extra2.example.com.",
|
"extra2.example.com.",
|
||||||
},
|
},
|
||||||
expectedMatchOnly: []string{
|
expectedMatchOnly: domain.List{
|
||||||
"extra1.example.com.",
|
"extra1.example.com.",
|
||||||
"extra2.example.com.",
|
"extra2.example.com.",
|
||||||
},
|
},
|
||||||
@@ -1519,12 +1518,12 @@ func TestExtraDomains(t *testing.T) {
|
|||||||
registerDomains: []domain.List{
|
registerDomains: []domain.List{
|
||||||
{"extra.example.com", "overlap.example.com"},
|
{"extra.example.com", "overlap.example.com"},
|
||||||
},
|
},
|
||||||
expectedDomains: []string{
|
expectedDomains: domain.List{
|
||||||
"config.example.com.",
|
"config.example.com.",
|
||||||
"overlap.example.com.",
|
"overlap.example.com.",
|
||||||
"extra.example.com.",
|
"extra.example.com.",
|
||||||
},
|
},
|
||||||
expectedMatchOnly: []string{
|
expectedMatchOnly: domain.List{
|
||||||
"extra.example.com.",
|
"extra.example.com.",
|
||||||
},
|
},
|
||||||
applyHostConfigCall: 2,
|
applyHostConfigCall: 2,
|
||||||
@@ -1544,12 +1543,12 @@ func TestExtraDomains(t *testing.T) {
|
|||||||
deregisterDomains: []domain.List{
|
deregisterDomains: []domain.List{
|
||||||
{"extra1.example.com", "extra3.example.com"},
|
{"extra1.example.com", "extra3.example.com"},
|
||||||
},
|
},
|
||||||
expectedDomains: []string{
|
expectedDomains: domain.List{
|
||||||
"config.example.com.",
|
"config.example.com.",
|
||||||
"extra2.example.com.",
|
"extra2.example.com.",
|
||||||
"extra4.example.com.",
|
"extra4.example.com.",
|
||||||
},
|
},
|
||||||
expectedMatchOnly: []string{
|
expectedMatchOnly: domain.List{
|
||||||
"extra2.example.com.",
|
"extra2.example.com.",
|
||||||
"extra4.example.com.",
|
"extra4.example.com.",
|
||||||
},
|
},
|
||||||
@@ -1570,13 +1569,13 @@ func TestExtraDomains(t *testing.T) {
|
|||||||
deregisterDomains: []domain.List{
|
deregisterDomains: []domain.List{
|
||||||
{"duplicate.example.com"},
|
{"duplicate.example.com"},
|
||||||
},
|
},
|
||||||
expectedDomains: []string{
|
expectedDomains: domain.List{
|
||||||
"config.example.com.",
|
"config.example.com.",
|
||||||
"extra.example.com.",
|
"extra.example.com.",
|
||||||
"other.example.com.",
|
"other.example.com.",
|
||||||
"duplicate.example.com.",
|
"duplicate.example.com.",
|
||||||
},
|
},
|
||||||
expectedMatchOnly: []string{
|
expectedMatchOnly: domain.List{
|
||||||
"extra.example.com.",
|
"extra.example.com.",
|
||||||
"other.example.com.",
|
"other.example.com.",
|
||||||
"duplicate.example.com.",
|
"duplicate.example.com.",
|
||||||
@@ -1601,13 +1600,13 @@ func TestExtraDomains(t *testing.T) {
|
|||||||
{Domain: "newconfig.example.com"},
|
{Domain: "newconfig.example.com"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedDomains: []string{
|
expectedDomains: domain.List{
|
||||||
"config.example.com.",
|
"config.example.com.",
|
||||||
"newconfig.example.com.",
|
"newconfig.example.com.",
|
||||||
"extra.example.com.",
|
"extra.example.com.",
|
||||||
"duplicate.example.com.",
|
"duplicate.example.com.",
|
||||||
},
|
},
|
||||||
expectedMatchOnly: []string{
|
expectedMatchOnly: domain.List{
|
||||||
"extra.example.com.",
|
"extra.example.com.",
|
||||||
"duplicate.example.com.",
|
"duplicate.example.com.",
|
||||||
},
|
},
|
||||||
@@ -1628,12 +1627,12 @@ func TestExtraDomains(t *testing.T) {
|
|||||||
deregisterDomains: []domain.List{
|
deregisterDomains: []domain.List{
|
||||||
{"protected.example.com"},
|
{"protected.example.com"},
|
||||||
},
|
},
|
||||||
expectedDomains: []string{
|
expectedDomains: domain.List{
|
||||||
"extra.example.com.",
|
"extra.example.com.",
|
||||||
"config.example.com.",
|
"config.example.com.",
|
||||||
"protected.example.com.",
|
"protected.example.com.",
|
||||||
},
|
},
|
||||||
expectedMatchOnly: []string{
|
expectedMatchOnly: domain.List{
|
||||||
"extra.example.com.",
|
"extra.example.com.",
|
||||||
},
|
},
|
||||||
applyHostConfigCall: 3,
|
applyHostConfigCall: 3,
|
||||||
@@ -1644,7 +1643,7 @@ func TestExtraDomains(t *testing.T) {
|
|||||||
ServiceEnable: true,
|
ServiceEnable: true,
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
{
|
{
|
||||||
Domains: []string{"ns.example.com", "overlap.ns.example.com"},
|
Domains: domain.List{"ns.example.com", "overlap.ns.example.com"},
|
||||||
NameServers: []nbdns.NameServer{
|
NameServers: []nbdns.NameServer{
|
||||||
{
|
{
|
||||||
IP: netip.MustParseAddr("8.8.8.8"),
|
IP: netip.MustParseAddr("8.8.8.8"),
|
||||||
@@ -1658,12 +1657,12 @@ func TestExtraDomains(t *testing.T) {
|
|||||||
registerDomains: []domain.List{
|
registerDomains: []domain.List{
|
||||||
{"extra.example.com", "overlap.ns.example.com"},
|
{"extra.example.com", "overlap.ns.example.com"},
|
||||||
},
|
},
|
||||||
expectedDomains: []string{
|
expectedDomains: domain.List{
|
||||||
"ns.example.com.",
|
"ns.example.com.",
|
||||||
"overlap.ns.example.com.",
|
"overlap.ns.example.com.",
|
||||||
"extra.example.com.",
|
"extra.example.com.",
|
||||||
},
|
},
|
||||||
expectedMatchOnly: []string{
|
expectedMatchOnly: domain.List{
|
||||||
"ns.example.com.",
|
"ns.example.com.",
|
||||||
"overlap.ns.example.com.",
|
"overlap.ns.example.com.",
|
||||||
"extra.example.com.",
|
"extra.example.com.",
|
||||||
@@ -1734,8 +1733,8 @@ func TestExtraDomains(t *testing.T) {
|
|||||||
lastConfig := capturedConfigs[len(capturedConfigs)-1]
|
lastConfig := capturedConfigs[len(capturedConfigs)-1]
|
||||||
|
|
||||||
// Check all expected domains are present
|
// Check all expected domains are present
|
||||||
domainMap := make(map[string]bool)
|
domainMap := make(map[domain.Domain]bool)
|
||||||
matchOnlyMap := make(map[string]bool)
|
matchOnlyMap := make(map[domain.Domain]bool)
|
||||||
|
|
||||||
for _, d := range lastConfig.Domains {
|
for _, d := range lastConfig.Domains {
|
||||||
domainMap[d.Domain] = true
|
domainMap[d.Domain] = true
|
||||||
@@ -1852,12 +1851,12 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
|
|||||||
err := server.applyConfiguration(initialConfig)
|
err := server.applyConfiguration(initialConfig)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
var domains []string
|
var domains domain.List
|
||||||
for _, d := range capturedConfig.Domains {
|
for _, d := range capturedConfig.Domains {
|
||||||
domains = append(domains, d.Domain)
|
domains = append(domains, d.Domain)
|
||||||
}
|
}
|
||||||
assert.Contains(t, domains, "config.example.com.")
|
assert.Contains(t, domains, domain.Domain("config.example.com."))
|
||||||
assert.Contains(t, domains, "extra.example.com.")
|
assert.Contains(t, domains, domain.Domain("extra.example.com."))
|
||||||
|
|
||||||
// Now apply a new configuration with overlapping domain
|
// Now apply a new configuration with overlapping domain
|
||||||
updatedConfig := nbdns.Config{
|
updatedConfig := nbdns.Config{
|
||||||
@@ -1871,7 +1870,7 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Verify both domains are in config, but no duplicates
|
// Verify both domains are in config, but no duplicates
|
||||||
domains = []string{}
|
domains = domain.List{}
|
||||||
matchOnlyCount := 0
|
matchOnlyCount := 0
|
||||||
for _, d := range capturedConfig.Domains {
|
for _, d := range capturedConfig.Domains {
|
||||||
domains = append(domains, d.Domain)
|
domains = append(domains, d.Domain)
|
||||||
@@ -1880,12 +1879,12 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Contains(t, domains, "config.example.com.")
|
assert.Contains(t, domains, domain.Domain("config.example.com."))
|
||||||
assert.Contains(t, domains, "extra.example.com.")
|
assert.Contains(t, domains, domain.Domain("extra.example.com."))
|
||||||
assert.Equal(t, 2, len(domains), "Should have exactly 2 domains with no duplicates")
|
assert.Equal(t, 2, len(domains), "Should have exactly 2 domains with no duplicates")
|
||||||
|
|
||||||
// Extra domain should no longer be marked as match-only when in config
|
// Extra domain should no longer be marked as match-only when in config
|
||||||
matchOnlyDomain := ""
|
var matchOnlyDomain domain.Domain
|
||||||
for _, d := range capturedConfig.Domains {
|
for _, d := range capturedConfig.Domains {
|
||||||
if d.Domain == "extra.example.com." && d.MatchOnly {
|
if d.Domain == "extra.example.com." && d.MatchOnly {
|
||||||
matchOnlyDomain = d.Domain
|
matchOnlyDomain = d.Domain
|
||||||
@@ -1938,10 +1937,10 @@ func TestDomainCaseHandling(t *testing.T) {
|
|||||||
err := server.applyConfiguration(config)
|
err := server.applyConfiguration(config)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
var domains []string
|
var domains domain.List
|
||||||
for _, d := range capturedConfig.Domains {
|
for _, d := range capturedConfig.Domains {
|
||||||
domains = append(domains, d.Domain)
|
domains = append(domains, d.Domain)
|
||||||
}
|
}
|
||||||
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent")
|
assert.Contains(t, domains, domain.Domain("config.example.com."), "Mixed case domain should be normalized and pre.sent")
|
||||||
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present")
|
assert.Contains(t, domains, domain.Domain("mixed.example.com."), "Mixed case domain should be normalized and present")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -117,15 +117,15 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
||||||
Domain: dConf.Domain,
|
Domain: dConf.Domain.PunycodeString(),
|
||||||
MatchOnly: dConf.MatchOnly,
|
MatchOnly: dConf.MatchOnly,
|
||||||
})
|
})
|
||||||
|
|
||||||
if dConf.MatchOnly {
|
if dConf.MatchOnly {
|
||||||
matchDomains = append(matchDomains, dConf.Domain)
|
matchDomains = append(matchDomains, dConf.Domain.PunycodeString())
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
searchDomains = append(searchDomains, dConf.Domain)
|
searchDomains = append(searchDomains, dConf.Domain.PunycodeString())
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.RouteAll {
|
if config.RouteAll {
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -48,7 +49,7 @@ type upstreamResolverBase struct {
|
|||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
upstreamClient upstreamClient
|
upstreamClient upstreamClient
|
||||||
upstreamServers []string
|
upstreamServers []string
|
||||||
domain string
|
domain domain.Domain
|
||||||
disabled bool
|
disabled bool
|
||||||
failsCount atomic.Int32
|
failsCount atomic.Int32
|
||||||
successCount atomic.Int32
|
successCount atomic.Int32
|
||||||
@@ -62,7 +63,7 @@ type upstreamResolverBase struct {
|
|||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
|
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain domain.Domain) *upstreamResolverBase {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
return &upstreamResolverBase{
|
return &upstreamResolverBase{
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -28,7 +29,7 @@ func newUpstreamResolver(
|
|||||||
_ netip.Prefix,
|
_ netip.Prefix,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
hostsDNSHolder *hostsDNSHolder,
|
hostsDNSHolder *hostsDNSHolder,
|
||||||
domain string,
|
domain domain.Domain,
|
||||||
) (*upstreamResolver, error) {
|
) (*upstreamResolver, error) {
|
||||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||||
c := &upstreamResolver{
|
c := &upstreamResolver{
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
type upstreamResolver struct {
|
type upstreamResolver struct {
|
||||||
@@ -23,7 +24,7 @@ func newUpstreamResolver(
|
|||||||
_ netip.Prefix,
|
_ netip.Prefix,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
_ *hostsDNSHolder,
|
_ *hostsDNSHolder,
|
||||||
domain string,
|
domain domain.Domain,
|
||||||
) (*upstreamResolver, error) {
|
) (*upstreamResolver, error) {
|
||||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||||
nonIOS := &upstreamResolver{
|
nonIOS := &upstreamResolver{
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
type upstreamResolverIOS struct {
|
type upstreamResolverIOS struct {
|
||||||
@@ -31,7 +32,7 @@ func newUpstreamResolver(
|
|||||||
net netip.Prefix,
|
net netip.Prefix,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
_ *hostsDNSHolder,
|
_ *hostsDNSHolder,
|
||||||
domain string,
|
domain domain.Domain,
|
||||||
) (*upstreamResolverIOS, error) {
|
) (*upstreamResolverIOS, error) {
|
||||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||||
|
|
||||||
|
|||||||
@@ -1165,7 +1165,7 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns
|
|||||||
for _, nsGroup := range protoDNSConfig.GetNameServerGroups() {
|
for _, nsGroup := range protoDNSConfig.GetNameServerGroups() {
|
||||||
dnsNSGroup := &nbdns.NameServerGroup{
|
dnsNSGroup := &nbdns.NameServerGroup{
|
||||||
Primary: nsGroup.GetPrimary(),
|
Primary: nsGroup.GetPrimary(),
|
||||||
Domains: nsGroup.GetDomains(),
|
Domains: domain.FromPunycodeList(nsGroup.GetDomains()),
|
||||||
SearchDomainsEnabled: nsGroup.GetSearchDomainsEnabled(),
|
SearchDomainsEnabled: nsGroup.GetSearchDomainsEnabled(),
|
||||||
}
|
}
|
||||||
for _, ns := range nsGroup.GetNameServers() {
|
for _, ns := range nsGroup.GetNameServers() {
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
mgmt "github.com/netbirdio/netbird/management/client"
|
mgmt "github.com/netbirdio/netbird/management/client"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
@@ -77,7 +78,7 @@ var (
|
|||||||
|
|
||||||
type MockWGIface struct {
|
type MockWGIface struct {
|
||||||
CreateFunc func() error
|
CreateFunc func() error
|
||||||
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
|
CreateOnAndroidFunc func(routeRange []string, ip string, domains domain.List) error
|
||||||
IsUserspaceBindFunc func() bool
|
IsUserspaceBindFunc func() bool
|
||||||
NameFunc func() string
|
NameFunc func() string
|
||||||
AddressFunc func() wgaddr.Address
|
AddressFunc func() wgaddr.Address
|
||||||
@@ -111,7 +112,7 @@ func (m *MockWGIface) Create() error {
|
|||||||
return m.CreateFunc()
|
return m.CreateFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error {
|
func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains domain.List) error {
|
||||||
return m.CreateOnAndroidFunc(routeRange, ip, domains)
|
return m.CreateOnAndroidFunc(routeRange, ip, domains)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,11 +14,12 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
type wgIfaceBase interface {
|
type wgIfaceBase interface {
|
||||||
Create() error
|
Create() error
|
||||||
CreateOnAndroid(routeRange []string, ip string, domains []string) error
|
CreateOnAndroid(routeRange []string, ip string, domains domain.List) error
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
Name() string
|
Name() string
|
||||||
Address() wgaddr.Address
|
Address() wgaddr.Address
|
||||||
|
|||||||
@@ -229,15 +229,14 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(r.Answer) > 0 && len(r.Question) > 0 {
|
if len(r.Answer) > 0 && len(r.Question) > 0 {
|
||||||
origPattern := ""
|
var origPattern domain.Domain
|
||||||
if writer, ok := w.(*nbdns.ResponseWriterChain); ok {
|
if writer, ok := w.(*nbdns.ResponseWriterChain); ok {
|
||||||
origPattern = writer.GetOrigPattern()
|
origPattern = writer.GetOrigPattern()
|
||||||
}
|
}
|
||||||
|
|
||||||
resolvedDomain := domain.Domain(strings.ToLower(r.Question[0].Name))
|
resolvedDomain := domain.Domain(strings.ToLower(r.Question[0].Name))
|
||||||
|
|
||||||
// already punycode via RegisterHandler()
|
originalDomain := origPattern
|
||||||
originalDomain := domain.Domain(origPattern)
|
|
||||||
if originalDomain == "" {
|
if originalDomain == "" {
|
||||||
originalDomain = resolvedDomain
|
originalDomain = resolvedDomain
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -64,7 +66,7 @@ type NameServerGroup struct {
|
|||||||
// Primary indicates that the nameserver group is the primary resolver for any dns query
|
// Primary indicates that the nameserver group is the primary resolver for any dns query
|
||||||
Primary bool
|
Primary bool
|
||||||
// Domains indicate the dns query domains to use with this nameserver group
|
// Domains indicate the dns query domains to use with this nameserver group
|
||||||
Domains []string `gorm:"serializer:json"`
|
Domains domain.List `gorm:"serializer:json"`
|
||||||
// Enabled group status
|
// Enabled group status
|
||||||
Enabled bool
|
Enabled bool
|
||||||
// SearchDomainsEnabled indicates whether to add match domains to search domains list or not
|
// SearchDomainsEnabled indicates whether to add match domains to search domains list or not
|
||||||
@@ -142,7 +144,7 @@ func (g *NameServerGroup) Copy() *NameServerGroup {
|
|||||||
Groups: make([]string, len(g.Groups)),
|
Groups: make([]string, len(g.Groups)),
|
||||||
Enabled: g.Enabled,
|
Enabled: g.Enabled,
|
||||||
Primary: g.Primary,
|
Primary: g.Primary,
|
||||||
Domains: make([]string, len(g.Domains)),
|
Domains: make(domain.List, len(g.Domains)),
|
||||||
SearchDomainsEnabled: g.SearchDomainsEnabled,
|
SearchDomainsEnabled: g.SearchDomainsEnabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -188,7 +190,7 @@ func containsNameServer(element NameServer, list []NameServer) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func compareGroupsList(list, other []string) bool {
|
func compareGroupsList[T comparable](list, other []T) bool {
|
||||||
if len(list) != len(other) {
|
if len(list) != len(other) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ func (d Domain) SafeString() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// PunycodeString returns the punycode representation of the Domain.
|
// PunycodeString returns the punycode representation of the Domain.
|
||||||
// This should only be used if a punycode domain is expected but only a string is supported.
|
// This should only be used if a punycode domain is expected but only a string is supported (e.g. an external library).
|
||||||
func (d Domain) PunycodeString() string {
|
func (d Domain) PunycodeString() string {
|
||||||
return string(d)
|
return string(d)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package domain
|
package domain
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sort"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -41,6 +41,7 @@ func (d List) ToSafeStringList() []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// String converts List to a comma-separated string.
|
// String converts List to a comma-separated string.
|
||||||
|
// This is useful for displaying domain names in a user-friendly format.
|
||||||
func (d List) String() (string, error) {
|
func (d List) String() (string, error) {
|
||||||
list, err := d.ToStringList()
|
list, err := d.ToStringList()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -50,7 +51,8 @@ func (d List) String() (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SafeString converts List to a comma-separated non-punycode string.
|
// SafeString converts List to a comma-separated non-punycode string.
|
||||||
// If a domain cannot be converted, the original string is used.
|
// This is useful for displaying domain names in a user-friendly format.
|
||||||
|
// If a domain cannot be converted, the original (punycode) string is used.
|
||||||
func (d List) SafeString() string {
|
func (d List) SafeString() string {
|
||||||
str, err := d.String()
|
str, err := d.String()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -64,28 +66,22 @@ func (d List) PunycodeString() string {
|
|||||||
return strings.Join(d.ToPunycodeList(), ", ")
|
return strings.Join(d.ToPunycodeList(), ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Equal checks if two domain lists are equal without considering the order.
|
||||||
func (d List) Equal(domains List) bool {
|
func (d List) Equal(domains List) bool {
|
||||||
if len(d) != len(domains) {
|
if len(d) != len(domains) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Slice(d, func(i, j int) bool {
|
d1 := slices.Clone(d)
|
||||||
return d[i] < d[j]
|
d2 := slices.Clone(domains)
|
||||||
})
|
|
||||||
|
|
||||||
sort.Slice(domains, func(i, j int) bool {
|
slices.Sort(d1)
|
||||||
return domains[i] < domains[j]
|
slices.Sort(d2)
|
||||||
})
|
|
||||||
|
|
||||||
for i, domain := range d {
|
return slices.Equal(d1, d2)
|
||||||
if domain != domains[i] {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FromStringList creates a DomainList from a slice of string.
|
// FromStringList creates a List from a slice of strings.
|
||||||
func FromStringList(s []string) (List, error) {
|
func FromStringList(s []string) (List, error) {
|
||||||
var dl List
|
var dl List
|
||||||
for _, domain := range s {
|
for _, domain := range s {
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ type Manager interface {
|
|||||||
DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error
|
DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error
|
||||||
ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error)
|
ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error)
|
||||||
GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||||
CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
|
CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains domain.List, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
|
||||||
SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||||
DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error
|
DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error
|
||||||
ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
||||||
|
|||||||
@@ -19,6 +19,9 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
@@ -1688,7 +1691,7 @@ func TestAccount_Copy(t *testing.T) {
|
|||||||
NameServerGroups: map[string]*nbdns.NameServerGroup{
|
NameServerGroups: map[string]*nbdns.NameServerGroup{
|
||||||
"nsGroup1": {
|
"nsGroup1": {
|
||||||
ID: "nsGroup1",
|
ID: "nsGroup1",
|
||||||
Domains: []string{},
|
Domains: domain.List{},
|
||||||
Groups: []string{},
|
Groups: []string{},
|
||||||
NameServers: []nbdns.NameServer{},
|
NameServers: []nbdns.NameServer{},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -258,7 +258,7 @@ func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
|
|||||||
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
|
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
|
||||||
protoGroup := &proto.NameServerGroup{
|
protoGroup := &proto.NameServerGroup{
|
||||||
Primary: nsGroup.Primary,
|
Primary: nsGroup.Primary,
|
||||||
Domains: nsGroup.Domains,
|
Domains: nsGroup.Domains.ToPunycodeList(),
|
||||||
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
|
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
|
||||||
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
|
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
@@ -367,7 +368,7 @@ func generateTestData(size int) nbdns.Config {
|
|||||||
config.NameServerGroups[i] = &nbdns.NameServerGroup{
|
config.NameServerGroups[i] = &nbdns.NameServerGroup{
|
||||||
ID: fmt.Sprintf("group%d", i),
|
ID: fmt.Sprintf("group%d", i),
|
||||||
Primary: i == 0,
|
Primary: i == 0,
|
||||||
Domains: []string{fmt.Sprintf("domain%d.com", i)},
|
Domains: domain.List{domain.Domain(fmt.Sprintf("domain%d.com", i))},
|
||||||
SearchDomainsEnabled: true,
|
SearchDomainsEnabled: true,
|
||||||
NameServers: []nbdns.NameServer{
|
NameServers: []nbdns.NameServer{
|
||||||
{
|
{
|
||||||
@@ -547,7 +548,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
|||||||
Port: dns.DefaultDNSPort,
|
Port: dns.DefaultDNSPort,
|
||||||
}},
|
}},
|
||||||
[]string{"groupB"},
|
[]string{"groupB"},
|
||||||
true, []string{}, true, userID, false,
|
true, domain.List{}, true, userID, false,
|
||||||
)
|
)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
@@ -580,7 +581,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
|||||||
Port: dns.DefaultDNSPort,
|
Port: dns.DefaultDNSPort,
|
||||||
}},
|
}},
|
||||||
[]string{"groupA"},
|
[]string{"groupA"},
|
||||||
true, []string{}, true, userID, false,
|
true, domain.List{}, true, userID, false,
|
||||||
)
|
)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
@@ -83,7 +84,13 @@ func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userID, req.SearchDomainsEnabled)
|
domains, err := domain.FromStringList(req.Domains)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains format"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, domains, req.Enabled, userID, req.SearchDomainsEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -123,12 +130,18 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
domains, err := domain.FromStringList(req.Domains)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains format"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
updatedNSGroup := &nbdns.NameServerGroup{
|
updatedNSGroup := &nbdns.NameServerGroup{
|
||||||
ID: nsGroupID,
|
ID: nsGroupID,
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
Description: req.Description,
|
Description: req.Description,
|
||||||
Primary: req.Primary,
|
Primary: req.Primary,
|
||||||
Domains: req.Domains,
|
Domains: domains,
|
||||||
NameServers: nsList,
|
NameServers: nsList,
|
||||||
Groups: req.Groups,
|
Groups: req.Groups,
|
||||||
Enabled: req.Enabled,
|
Enabled: req.Enabled,
|
||||||
@@ -227,7 +240,7 @@ func toNameserverGroupResponse(serverNSGroup *nbdns.NameServerGroup) *api.Namese
|
|||||||
Name: serverNSGroup.Name,
|
Name: serverNSGroup.Name,
|
||||||
Description: serverNSGroup.Description,
|
Description: serverNSGroup.Description,
|
||||||
Primary: serverNSGroup.Primary,
|
Primary: serverNSGroup.Primary,
|
||||||
Domains: serverNSGroup.Domains,
|
Domains: serverNSGroup.Domains.ToSafeStringList(),
|
||||||
Groups: serverNSGroup.Groups,
|
Groups: serverNSGroup.Groups,
|
||||||
Nameservers: nsList,
|
Nameservers: nsList,
|
||||||
Enabled: serverNSGroup.Enabled,
|
Enabled: serverNSGroup.Enabled,
|
||||||
|
|||||||
@@ -10,17 +10,15 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
|
||||||
|
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -47,7 +45,7 @@ var baseExistingNSGroup = &nbdns.NameServerGroup{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
Groups: []string{"testing"},
|
Groups: []string{"testing"},
|
||||||
Domains: []string{"domain"},
|
Domains: domain.List{"domain"},
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -60,7 +58,7 @@ func initNameserversTestData() *nameserversHandler {
|
|||||||
}
|
}
|
||||||
return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
|
return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
|
||||||
},
|
},
|
||||||
CreateNameServerGroupFunc: func(_ context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, _ string, searchDomains bool) (*nbdns.NameServerGroup, error) {
|
CreateNameServerGroupFunc: func(_ context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains domain.List, enabled bool, _ string, searchDomains bool) (*nbdns.NameServerGroup, error) {
|
||||||
return &nbdns.NameServerGroup{
|
return &nbdns.NameServerGroup{
|
||||||
ID: existingNSGroupID,
|
ID: existingNSGroupID,
|
||||||
Name: name,
|
Name: name,
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ type MockAccountManager struct {
|
|||||||
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error)
|
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error)
|
||||||
GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*types.PersonalAccessToken, error)
|
GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*types.PersonalAccessToken, error)
|
||||||
GetNameServerGroupFunc func(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
GetNameServerGroupFunc func(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||||
CreateNameServerGroupFunc func(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
|
CreateNameServerGroupFunc func(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains domain.List, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
|
||||||
SaveNameServerGroupFunc func(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
SaveNameServerGroupFunc func(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||||
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
|
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
|
||||||
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
||||||
@@ -567,7 +567,7 @@ func (am *MockAccountManager) GetNameServerGroup(ctx context.Context, accountID,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateNameServerGroup mocks CreateNameServerGroup of the AccountManager interface
|
// CreateNameServerGroup mocks CreateNameServerGroup of the AccountManager interface
|
||||||
func (am *MockAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) {
|
func (am *MockAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains domain.List, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) {
|
||||||
if am.CreateNameServerGroupFunc != nil {
|
if am.CreateNameServerGroupFunc != nil {
|
||||||
return am.CreateNameServerGroupFunc(ctx, accountID, name, description, nameServerList, groups, primary, domains, enabled, userID, searchDomainsEnabled)
|
return am.CreateNameServerGroupFunc(ctx, accountID, name, description, nameServerList, groups, primary, domains, enabled, userID, searchDomainsEnabled)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
@@ -18,7 +19,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*[*.a-z]{1,}$`
|
const domainPattern = `^(?i)[a-z0-9]+([\-]+[a-z0-9]+)*[*.a-z]{1,}$`
|
||||||
|
|
||||||
var invalidDomainName = errors.New("invalid domain name")
|
var invalidDomainName = errors.New("invalid domain name")
|
||||||
|
|
||||||
@@ -36,7 +37,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateNameServerGroup creates and saves a new nameserver group
|
// CreateNameServerGroup creates and saves a new nameserver group
|
||||||
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
|
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains domain.List, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@@ -252,7 +253,7 @@ func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction store
|
|||||||
return anyGroupHasPeersOrResources(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups)
|
return anyGroupHasPeersOrResources(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups)
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error {
|
func validateDomainInput(primary bool, domains domain.List, searchDomainsEnabled bool) error {
|
||||||
if !primary && len(domains) == 0 {
|
if !primary && len(domains) == 0 {
|
||||||
return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+
|
return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+
|
||||||
" it should be primary or have at least one domain")
|
" it should be primary or have at least one domain")
|
||||||
@@ -268,7 +269,7 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, domain := range domains {
|
for _, domain := range domains {
|
||||||
if err := validateDomain(domain); err != nil {
|
if err := validateDomain(domain.PunycodeString()); err != nil {
|
||||||
return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s %q", domain, err)
|
return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s %q", domain, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
@@ -41,7 +42,7 @@ func TestCreateNameServerGroup(t *testing.T) {
|
|||||||
groups []string
|
groups []string
|
||||||
nameServers []nbdns.NameServer
|
nameServers []nbdns.NameServer
|
||||||
primary bool
|
primary bool
|
||||||
domains []string
|
domains domain.List
|
||||||
searchDomains bool
|
searchDomains bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -102,7 +103,7 @@ func TestCreateNameServerGroup(t *testing.T) {
|
|||||||
description: "super",
|
description: "super",
|
||||||
groups: []string{group1ID},
|
groups: []string{group1ID},
|
||||||
primary: false,
|
primary: false,
|
||||||
domains: []string{validDomain},
|
domains: domain.List{validDomain},
|
||||||
nameServers: []nbdns.NameServer{
|
nameServers: []nbdns.NameServer{
|
||||||
{
|
{
|
||||||
IP: netip.MustParseAddr("1.1.1.1"),
|
IP: netip.MustParseAddr("1.1.1.1"),
|
||||||
@@ -123,7 +124,7 @@ func TestCreateNameServerGroup(t *testing.T) {
|
|||||||
Name: "super",
|
Name: "super",
|
||||||
Description: "super",
|
Description: "super",
|
||||||
Primary: false,
|
Primary: false,
|
||||||
Domains: []string{"example.com"},
|
Domains: domain.List{"example.com"},
|
||||||
Groups: []string{group1ID},
|
Groups: []string{group1ID},
|
||||||
NameServers: []nbdns.NameServer{
|
NameServers: []nbdns.NameServer{
|
||||||
{
|
{
|
||||||
@@ -360,7 +361,7 @@ func TestCreateNameServerGroup(t *testing.T) {
|
|||||||
name: "super",
|
name: "super",
|
||||||
description: "super",
|
description: "super",
|
||||||
groups: []string{group1ID},
|
groups: []string{group1ID},
|
||||||
domains: []string{invalidDomain},
|
domains: domain.List{invalidDomain},
|
||||||
nameServers: []nbdns.NameServer{
|
nameServers: []nbdns.NameServer{
|
||||||
{
|
{
|
||||||
IP: netip.MustParseAddr("1.1.1.1"),
|
IP: netip.MustParseAddr("1.1.1.1"),
|
||||||
@@ -447,8 +448,8 @@ func TestSaveNameServerGroup(t *testing.T) {
|
|||||||
validGroups := []string{group2ID}
|
validGroups := []string{group2ID}
|
||||||
invalidGroups := []string{"nonExisting"}
|
invalidGroups := []string{"nonExisting"}
|
||||||
disabledPrimary := false
|
disabledPrimary := false
|
||||||
validDomains := []string{validDomain}
|
validDomains := domain.List{validDomain}
|
||||||
invalidDomains := []string{invalidDomain}
|
invalidDomains := domain.List{invalidDomain}
|
||||||
|
|
||||||
validNameServerList := []nbdns.NameServer{
|
validNameServerList := []nbdns.NameServer{
|
||||||
{
|
{
|
||||||
@@ -491,7 +492,7 @@ func TestSaveNameServerGroup(t *testing.T) {
|
|||||||
newID *string
|
newID *string
|
||||||
newName *string
|
newName *string
|
||||||
newPrimary *bool
|
newPrimary *bool
|
||||||
newDomains []string
|
newDomains domain.List
|
||||||
newNSList []nbdns.NameServer
|
newNSList []nbdns.NameServer
|
||||||
newGroups []string
|
newGroups []string
|
||||||
skipCopying bool
|
skipCopying bool
|
||||||
@@ -908,6 +909,11 @@ func TestValidateDomain(t *testing.T) {
|
|||||||
domain: "example.",
|
domain: "example.",
|
||||||
errFunc: require.NoError,
|
errFunc: require.NoError,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "Valid domain name with double hyphen",
|
||||||
|
domain: "xn--bcher-kva.com",
|
||||||
|
errFunc: require.NoError,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "Invalid wildcard domain name",
|
name: "Invalid wildcard domain name",
|
||||||
domain: "*.example",
|
domain: "*.example",
|
||||||
@@ -924,8 +930,8 @@ func TestValidateDomain(t *testing.T) {
|
|||||||
errFunc: require.Error,
|
errFunc: require.Error,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Invalid domain name with double hyphen",
|
name: "Invalid domain name with double dot",
|
||||||
domain: "test--example.com",
|
domain: "example..com",
|
||||||
errFunc: require.Error,
|
errFunc: require.Error,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -1009,7 +1015,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
|||||||
Port: nbdns.DefaultDNSPort,
|
Port: nbdns.DefaultDNSPort,
|
||||||
}},
|
}},
|
||||||
[]string{"groupA"},
|
[]string{"groupA"},
|
||||||
true, []string{}, true, userID, false,
|
true, domain.List{}, true, userID, false,
|
||||||
)
|
)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
@@ -1054,7 +1060,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
|||||||
Port: nbdns.DefaultDNSPort,
|
Port: nbdns.DefaultDNSPort,
|
||||||
}},
|
}},
|
||||||
[]string{"groupB"},
|
[]string{"groupB"},
|
||||||
true, []string{}, true, userID, false,
|
true, domain.List{}, true, userID, false,
|
||||||
)
|
)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
|||||||
@@ -1108,7 +1108,7 @@ func TestToSyncResponse(t *testing.T) {
|
|||||||
Port: nbdns.DefaultDNSPort,
|
Port: nbdns.DefaultDNSPort,
|
||||||
}},
|
}},
|
||||||
Primary: true,
|
Primary: true,
|
||||||
Domains: []string{"example.com"},
|
Domains: domain.List{"example.com"},
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
SearchDomainsEnabled: true,
|
SearchDomainsEnabled: true,
|
||||||
},
|
},
|
||||||
@@ -1121,7 +1121,7 @@ func TestToSyncResponse(t *testing.T) {
|
|||||||
}},
|
}},
|
||||||
Groups: []string{"group1"},
|
Groups: []string{"group1"},
|
||||||
Primary: true,
|
Primary: true,
|
||||||
Domains: []string{"example.com"},
|
Domains: domain.List{"example.com"},
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
SearchDomainsEnabled: true,
|
SearchDomainsEnabled: true,
|
||||||
},
|
},
|
||||||
@@ -1995,7 +1995,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
|||||||
Port: nbdns.DefaultDNSPort,
|
Port: nbdns.DefaultDNSPort,
|
||||||
}},
|
}},
|
||||||
[]string{"groupC"},
|
[]string{"groupC"},
|
||||||
true, []string{}, true, userID, false,
|
true, domain.List{}, true, userID, false,
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user