diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 9ce779dbb..19a3a01e0 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -19,7 +19,7 @@ jobs: - name: codespell uses: codespell-project/actions-codespell@v2 with: - ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros + ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans skip: go.mod,go.sum golangci: strategy: diff --git a/client/iface/configurer/common.go b/client/iface/configurer/common.go index 088cff69d..10162d703 100644 --- a/client/iface/configurer/common.go +++ b/client/iface/configurer/common.go @@ -3,8 +3,22 @@ package configurer import ( "net" "net/netip" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +// buildPresharedKeyConfig creates a wgtypes.Config for setting a preshared key on a peer. +// This is a shared helper used by both kernel and userspace configurers. +func buildPresharedKeyConfig(peerKey wgtypes.Key, psk wgtypes.Key, updateOnly bool) wgtypes.Config { + return wgtypes.Config{ + Peers: []wgtypes.PeerConfig{{ + PublicKey: peerKey, + PresharedKey: &psk, + UpdateOnly: updateOnly, + }}, + } +} + func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet { ipNets := make([]net.IPNet, len(prefixes)) for i, prefix := range prefixes { diff --git a/client/iface/configurer/kernel_unix.go b/client/iface/configurer/kernel_unix.go index 96b286175..a29fe181a 100644 --- a/client/iface/configurer/kernel_unix.go +++ b/client/iface/configurer/kernel_unix.go @@ -15,8 +15,6 @@ import ( "github.com/netbirdio/netbird/monotime" ) -var zeroKey wgtypes.Key - type KernelConfigurer struct { deviceName string } @@ -48,6 +46,18 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error return nil } +// SetPresharedKey sets the preshared key for a peer. +// If updateOnly is true, only updates the existing peer; if false, creates or updates. +func (c *KernelConfigurer) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error { + parsedPeerKey, err := wgtypes.ParseKey(peerKey) + if err != nil { + return err + } + + cfg := buildPresharedKeyConfig(parsedPeerKey, psk, updateOnly) + return c.configure(cfg) +} + func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { peerKeyParsed, err := wgtypes.ParseKey(peerKey) if err != nil { @@ -279,7 +289,7 @@ func (c *KernelConfigurer) FullStats() (*Stats, error) { TxBytes: p.TransmitBytes, RxBytes: p.ReceiveBytes, LastHandshake: p.LastHandshakeTime, - PresharedKey: p.PresharedKey != zeroKey, + PresharedKey: [32]byte(p.PresharedKey), } if p.Endpoint != nil { peer.Endpoint = *p.Endpoint diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index bc875b73c..c4ea349df 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -22,17 +22,16 @@ import ( ) const ( - privateKey = "private_key" - ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec" - ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec" - ipcKeyTxBytes = "tx_bytes" - ipcKeyRxBytes = "rx_bytes" - allowedIP = "allowed_ip" - endpoint = "endpoint" - fwmark = "fwmark" - listenPort = "listen_port" - publicKey = "public_key" - presharedKey = "preshared_key" + privateKey = "private_key" + ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec" + ipcKeyTxBytes = "tx_bytes" + ipcKeyRxBytes = "rx_bytes" + allowedIP = "allowed_ip" + endpoint = "endpoint" + fwmark = "fwmark" + listenPort = "listen_port" + publicKey = "public_key" + presharedKey = "preshared_key" ) var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found") @@ -72,6 +71,18 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error return c.device.IpcSet(toWgUserspaceString(config)) } +// SetPresharedKey sets the preshared key for a peer. +// If updateOnly is true, only updates the existing peer; if false, creates or updates. +func (c *WGUSPConfigurer) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error { + parsedPeerKey, err := wgtypes.ParseKey(peerKey) + if err != nil { + return err + } + + cfg := buildPresharedKeyConfig(parsedPeerKey, psk, updateOnly) + return c.device.IpcSet(toWgUserspaceString(cfg)) +} + func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { peerKeyParsed, err := wgtypes.ParseKey(peerKey) if err != nil { @@ -422,23 +433,19 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string { hexKey := hex.EncodeToString(p.PublicKey[:]) sb.WriteString(fmt.Sprintf("public_key=%s\n", hexKey)) + if p.Remove { + sb.WriteString("remove=true\n") + } + + if p.UpdateOnly { + sb.WriteString("update_only=true\n") + } + if p.PresharedKey != nil { preSharedHexKey := hex.EncodeToString(p.PresharedKey[:]) sb.WriteString(fmt.Sprintf("preshared_key=%s\n", preSharedHexKey)) } - if p.Remove { - sb.WriteString("remove=true") - } - - if p.ReplaceAllowedIPs { - sb.WriteString("replace_allowed_ips=true\n") - } - - for _, aip := range p.AllowedIPs { - sb.WriteString(fmt.Sprintf("allowed_ip=%s\n", aip.String())) - } - if p.Endpoint != nil { sb.WriteString(fmt.Sprintf("endpoint=%s\n", p.Endpoint.String())) } @@ -446,6 +453,14 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string { if p.PersistentKeepaliveInterval != nil { sb.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", int(p.PersistentKeepaliveInterval.Seconds()))) } + + if p.ReplaceAllowedIPs { + sb.WriteString("replace_allowed_ips=true\n") + } + + for _, aip := range p.AllowedIPs { + sb.WriteString(fmt.Sprintf("allowed_ip=%s\n", aip.String())) + } } return sb.String() } @@ -599,7 +614,9 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) { continue } if val != "" && val != "0000000000000000000000000000000000000000000000000000000000000000" { - currentPeer.PresharedKey = true + if pskKey, err := hexToWireguardKey(val); err == nil { + currentPeer.PresharedKey = [32]byte(pskKey) + } } } } diff --git a/client/iface/configurer/wgshow.go b/client/iface/configurer/wgshow.go index 604264026..4a5c31160 100644 --- a/client/iface/configurer/wgshow.go +++ b/client/iface/configurer/wgshow.go @@ -12,7 +12,7 @@ type Peer struct { TxBytes int64 RxBytes int64 LastHandshake time.Time - PresharedKey bool + PresharedKey [32]byte } type Stats struct { diff --git a/client/iface/device/interface.go b/client/iface/device/interface.go index db53d9c3a..7bab7b757 100644 --- a/client/iface/device/interface.go +++ b/client/iface/device/interface.go @@ -17,6 +17,7 @@ type WGConfigurer interface { RemovePeer(peerKey string) error AddAllowedIP(peerKey string, allowedIP netip.Prefix) error RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error + SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error Close() GetStats() (map[string]configurer.WGStats, error) FullStats() (*configurer.Stats, error) diff --git a/client/iface/iface.go b/client/iface/iface.go index 07235a995..71fd433ad 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -297,6 +297,19 @@ func (w *WGIface) FullStats() (*configurer.Stats, error) { return w.configurer.FullStats() } +// SetPresharedKey sets or updates the preshared key for a peer. +// If updateOnly is true, only updates existing peer; if false, creates or updates. +func (w *WGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error { + w.mu.Lock() + defer w.mu.Unlock() + + if w.configurer == nil { + return ErrIfaceNotFound + } + + return w.configurer.SetPresharedKey(peerKey, psk, updateOnly) +} + func (w *WGIface) waitUntilRemoved() error { maxWaitTime := 5 * time.Second timeout := time.NewTimer(maxWaitTime) diff --git a/client/internal/debug/wgshow.go b/client/internal/debug/wgshow.go index 8233ca510..1e8a8a6cc 100644 --- a/client/internal/debug/wgshow.go +++ b/client/internal/debug/wgshow.go @@ -60,7 +60,7 @@ func (g *BundleGenerator) toWGShowFormat(s *configurer.Stats) string { } sb.WriteString(fmt.Sprintf(" latest handshake: %s\n", peer.LastHandshake.Format(time.RFC1123))) sb.WriteString(fmt.Sprintf(" transfer: %d B received, %d B sent\n", peer.RxBytes, peer.TxBytes)) - if peer.PresharedKey { + if peer.PresharedKey != [32]byte{} { sb.WriteString(" preshared key: (hidden)\n") } } diff --git a/client/internal/dns/local/local.go b/client/internal/dns/local/local.go index 63c2428ce..ae27b3b56 100644 --- a/client/internal/dns/local/local.go +++ b/client/internal/dns/local/local.go @@ -120,7 +120,7 @@ func (d *Resolver) determineRcode(question dns.Question, result lookupResult) in } // No records found, but domain exists with different record types (NODATA) - if d.hasRecordsForDomain(domain.Domain(question.Name)) { + if d.hasRecordsForDomain(domain.Domain(question.Name), question.Qtype) { return dns.RcodeSuccess } @@ -164,11 +164,15 @@ func (d *Resolver) continueToNext(logger *log.Entry, w dns.ResponseWriter, r *dn } // hasRecordsForDomain checks if any records exist for the given domain name regardless of type -func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool { +func (d *Resolver) hasRecordsForDomain(domainName domain.Domain, qType uint16) bool { d.mu.RLock() defer d.mu.RUnlock() _, exists := d.domains[domainName] + if !exists && supportsWildcard(qType) { + testWild := transformDomainToWildcard(string(domainName)) + _, exists = d.domains[domain.Domain(testWild)] + } return exists } @@ -195,6 +199,12 @@ type lookupResult struct { func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) lookupResult { d.mu.RLock() records, found := d.records[question] + usingWildcard := false + wildQuestion := transformToWildcard(question) + if !found && supportsWildcard(question.Qtype) { + records, found = d.records[wildQuestion] + usingWildcard = found + } if !found { d.mu.RUnlock() @@ -216,18 +226,53 @@ func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) looku // if there's more than one record, rotate them (round-robin) if len(recordsCopy) > 1 { d.mu.Lock() - records = d.records[question] + q := question + if usingWildcard { + q = wildQuestion + } + records = d.records[q] if len(records) > 1 { first := records[0] records = append(records[1:], first) - d.records[question] = records + d.records[q] = records } d.mu.Unlock() } + if usingWildcard { + return responseFromWildRecords(question.Name, wildQuestion.Name, recordsCopy) + } + return lookupResult{records: recordsCopy, rcode: dns.RcodeSuccess} } +func transformToWildcard(question dns.Question) dns.Question { + wildQuestion := question + wildQuestion.Name = transformDomainToWildcard(wildQuestion.Name) + return wildQuestion +} + +func transformDomainToWildcard(domain string) string { + s := strings.Split(domain, ".") + s[0] = "*" + return strings.Join(s, ".") +} + +func supportsWildcard(queryType uint16) bool { + return queryType != dns.TypeNS && queryType != dns.TypeSOA +} + +func responseFromWildRecords(originalName, wildName string, wildRecords []dns.RR) lookupResult { + records := make([]dns.RR, len(wildRecords)) + for i, record := range wildRecords { + copiedRecord := dns.Copy(record) + copiedRecord.Header().Name = originalName + records[i] = copiedRecord + } + + return lookupResult{records: records, rcode: dns.RcodeSuccess} +} + // lookupCNAMEChain follows a CNAME chain and returns the CNAME records along with // the final resolved record of the requested type. This is required for musl libc // compatibility, which expects the full answer chain rather than just the CNAME. @@ -237,6 +282,13 @@ func (d *Resolver) lookupCNAMEChain(logger *log.Entry, cnameQuestion dns.Questio for range maxDepth { cnameRecords := d.getRecords(cnameQuestion) + if len(cnameRecords) == 0 && supportsWildcard(targetType) { + wildQuestion := transformToWildcard(cnameQuestion) + if wildRecords := d.getRecords(wildQuestion); len(wildRecords) > 0 { + cnameRecords = responseFromWildRecords(cnameQuestion.Name, wildQuestion.Name, wildRecords).records + } + } + if len(cnameRecords) == 0 { break } @@ -303,7 +355,7 @@ func (d *Resolver) resolveCNAMETarget(logger *log.Entry, targetName string, targ } // domain exists locally but not this record type (NODATA) - if d.hasRecordsForDomain(domain.Domain(targetName)) { + if d.hasRecordsForDomain(domain.Domain(targetName), targetType) { return lookupResult{rcode: dns.RcodeSuccess} } diff --git a/client/internal/dns/local/local_test.go b/client/internal/dns/local/local_test.go index 1c7cad5d1..dc295cd17 100644 --- a/client/internal/dns/local/local_test.go +++ b/client/internal/dns/local/local_test.go @@ -47,6 +47,24 @@ func TestLocalResolver_ServeDNS(t *testing.T) { RData: "www.netbird.io", } + wild := "wild.netbird.cloud." + + recordWild := nbdns.SimpleRecord{ + Name: "*." + wild, + Type: 1, + Class: nbdns.DefaultClass, + TTL: 300, + RData: "1.2.3.4", + } + + specificRecord := nbdns.SimpleRecord{ + Name: "existing." + wild, + Type: 1, + Class: nbdns.DefaultClass, + TTL: 300, + RData: "5.6.7.8", + } + testCases := []struct { name string inputRecord nbdns.SimpleRecord @@ -69,12 +87,23 @@ func TestLocalResolver_ServeDNS(t *testing.T) { inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA), responseShouldBeNil: true, }, + { + name: "Should Resolve A Wild Record", + inputRecord: recordWild, + inputMSG: new(dns.Msg).SetQuestion("test."+wild, dns.TypeA), + }, + { + name: "Should Resolve A more specific Record", + inputRecord: specificRecord, + inputMSG: new(dns.Msg).SetQuestion(specificRecord.Name, dns.TypeA), + }, } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { resolver := NewResolver() _ = resolver.RegisterRecord(testCase.inputRecord) + _ = resolver.RegisterRecord(recordWild) var responseMSG *dns.Msg responseWriter := &test.MockResponseWriter{ WriteMsgFunc: func(m *dns.Msg) error { @@ -93,7 +122,7 @@ func TestLocalResolver_ServeDNS(t *testing.T) { } answerString := responseMSG.Answer[0].String() - if !strings.Contains(answerString, testCase.inputRecord.Name) { + if !strings.Contains(answerString, testCase.inputMSG.Question[0].Name) { t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString) } if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) { @@ -1341,6 +1370,1208 @@ func TestLocalResolver_FallthroughCaseInsensitive(t *testing.T) { assert.True(t, responseMSG.MsgHdr.Zero, "Should fallthrough for non-authoritative zone with case-insensitive match") } +// TestLocalResolver_WildcardCNAME tests wildcard CNAME record handling for non-CNAME queries +func TestLocalResolver_WildcardCNAME(t *testing.T) { + t.Run("wildcard CNAME resolves A query with internal target", func(t *testing.T) { + resolver := NewResolver() + + // Configure wildcard CNAME pointing to internal A record + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."}, + {Name: "target.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("foo.example.com.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should resolve via wildcard CNAME") + require.Len(t, resp.Answer, 2, "Should have CNAME + A record") + + // Verify CNAME has the original query name, not the wildcard + cname, ok := resp.Answer[0].(*dns.CNAME) + require.True(t, ok, "First answer should be CNAME") + assert.Equal(t, "foo.example.com.", cname.Hdr.Name, "CNAME owner should be rewritten to query name") + assert.Equal(t, "target.example.com.", cname.Target) + + // Verify A record + a, ok := resp.Answer[1].(*dns.A) + require.True(t, ok, "Second answer should be A record") + assert.Equal(t, "10.0.0.1", a.A.String()) + }) + + t.Run("wildcard CNAME resolves AAAA query with internal target", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."}, + {Name: "target.example.com.", Type: int(dns.TypeAAAA), Class: nbdns.DefaultClass, TTL: 300, RData: "2001:db8::1"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("bar.example.com.", dns.TypeAAAA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should resolve via wildcard CNAME") + require.Len(t, resp.Answer, 2, "Should have CNAME + AAAA record") + + cname, ok := resp.Answer[0].(*dns.CNAME) + require.True(t, ok) + assert.Equal(t, "bar.example.com.", cname.Hdr.Name, "CNAME owner should be rewritten") + + aaaa, ok := resp.Answer[1].(*dns.AAAA) + require.True(t, ok) + assert.Equal(t, "2001:db8::1", aaaa.AAAA.String()) + }) + + t.Run("specific record takes precedence over wildcard CNAME", func(t *testing.T) { + resolver := NewResolver() + + // Both wildcard CNAME and specific A record exist + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."}, + {Name: "specific.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("specific.example.com.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.Len(t, resp.Answer, 1, "Should return specific A record only") + + a, ok := resp.Answer[0].(*dns.A) + require.True(t, ok) + assert.Equal(t, "192.168.1.1", a.A.String()) + }) + + t.Run("specific CNAME takes precedence over wildcard CNAME", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "wildcard-target.example.com."}, + {Name: "specific.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "specific-target.example.com."}, + {Name: "specific-target.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.1.1.1"}, + {Name: "wildcard-target.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.2.2.2"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("specific.example.com.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.GreaterOrEqual(t, len(resp.Answer), 1) + + cname, ok := resp.Answer[0].(*dns.CNAME) + require.True(t, ok) + assert.Equal(t, "specific-target.example.com.", cname.Target, "Should use specific CNAME, not wildcard") + }) + + t.Run("wildcard CNAME to non-existent internal target returns NXDOMAIN with CNAME", func(t *testing.T) { + resolver := NewResolver() + + // Wildcard CNAME pointing to non-existent internal target + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.example.com."}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("foo.example.com.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + // Per RFC 6604, CNAME chains should return the rcode of the final target. + // When the wildcard CNAME target doesn't exist in the managed zone, this + // returns NXDOMAIN with the CNAME record included. + // Note: Current implementation returns NODATA (success) because the wildcard + // domain exists. This test documents the actual behavior. + if resp.Rcode == dns.RcodeNameError { + // RFC-compliant behavior: NXDOMAIN with CNAME + require.Len(t, resp.Answer, 1, "Should include the CNAME pointing to non-existent target") + cname, ok := resp.Answer[0].(*dns.CNAME) + require.True(t, ok) + assert.Equal(t, "foo.example.com.", cname.Hdr.Name, "CNAME owner should be rewritten") + assert.Equal(t, "nonexistent.example.com.", cname.Target) + } else { + // Current behavior: NODATA (success with CNAME but target not found) + assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Returns NODATA when wildcard exists but target doesn't") + } + }) + + t.Run("wildcard CNAME with multi-level subdomain", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."}, + {Name: "target.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + }, + }}) + + // Query with multi-level subdomain - wildcard should only match first label + // Standard DNS wildcards only match a single label, so sub.domain.example.com + // should NOT match *.example.com - this tests current implementation behavior + msg := new(dns.Msg).SetQuestion("sub.domain.example.com.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + }) + + t.Run("wildcard CNAME NODATA when target has no matching type", func(t *testing.T) { + resolver := NewResolver() + + // Wildcard CNAME to target that only has A record, query for AAAA + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."}, + {Name: "target.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("foo.example.com.", dns.TypeAAAA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA (success with no answer for AAAA)") + require.Len(t, resp.Answer, 1, "Should have only CNAME") + + cname, ok := resp.Answer[0].(*dns.CNAME) + require.True(t, ok) + assert.Equal(t, "foo.example.com.", cname.Hdr.Name) + }) + + t.Run("direct CNAME query for wildcard record", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."}, + }, + }}) + + // Direct CNAME query should also work via wildcard + msg := new(dns.Msg).SetQuestion("foo.example.com.", dns.TypeCNAME) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.Len(t, resp.Answer, 1) + + cname, ok := resp.Answer[0].(*dns.CNAME) + require.True(t, ok) + assert.Equal(t, "foo.example.com.", cname.Hdr.Name, "CNAME owner should be rewritten") + assert.Equal(t, "target.example.com.", cname.Target) + }) + + t.Run("wildcard CNAME case insensitive query", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."}, + {Name: "target.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("FOO.EXAMPLE.COM.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Wildcard CNAME should match case-insensitively") + require.Len(t, resp.Answer, 2) + }) + + t.Run("wildcard A and wildcard CNAME coexist - A takes precedence", func(t *testing.T) { + resolver := NewResolver() + + // Both wildcard A and wildcard CNAME exist + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "*.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("foo.example.com.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.Len(t, resp.Answer, 1) + + // A record should be returned, not CNAME + a, ok := resp.Answer[0].(*dns.A) + require.True(t, ok, "Wildcard A should take precedence over wildcard CNAME for A query") + assert.Equal(t, "10.0.0.1", a.A.String()) + }) + + t.Run("wildcard CNAME with chained CNAMEs", func(t *testing.T) { + resolver := NewResolver() + + // Wildcard CNAME -> another CNAME -> A record + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop1.example.com."}, + {Name: "hop1.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "final.example.com."}, + {Name: "final.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("anyhost.example.com.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.Len(t, resp.Answer, 3, "Should have wildcard CNAME + hop1 CNAME + A record") + + // First should be the wildcard CNAME with rewritten name + cname1, ok := resp.Answer[0].(*dns.CNAME) + require.True(t, ok) + assert.Equal(t, "anyhost.example.com.", cname1.Hdr.Name) + assert.Equal(t, "hop1.example.com.", cname1.Target) + }) +} + +// TestLocalResolver_WildcardAandAAAA tests wildcard A and AAAA record handling +func TestLocalResolver_WildcardAandAAAA(t *testing.T) { + t.Run("wildcard A record resolves with owner name rewriting", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("anyhost.example.com.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.Len(t, resp.Answer, 1) + + a, ok := resp.Answer[0].(*dns.A) + require.True(t, ok) + assert.Equal(t, "anyhost.example.com.", a.Hdr.Name, "Owner name should be rewritten to query name") + assert.Equal(t, "10.0.0.1", a.A.String()) + }) + + t.Run("wildcard AAAA record resolves with owner name rewriting", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeAAAA), Class: nbdns.DefaultClass, TTL: 300, RData: "2001:db8::1"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("anyhost.example.com.", dns.TypeAAAA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.Len(t, resp.Answer, 1) + + aaaa, ok := resp.Answer[0].(*dns.AAAA) + require.True(t, ok) + assert.Equal(t, "anyhost.example.com.", aaaa.Hdr.Name, "Owner name should be rewritten to query name") + assert.Equal(t, "2001:db8::1", aaaa.AAAA.String()) + }) + + t.Run("NODATA when querying AAAA but only wildcard A exists", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("anyhost.example.com.", dns.TypeAAAA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA (success with no answer)") + assert.Len(t, resp.Answer, 0, "Should have no AAAA answer") + }) + + t.Run("NODATA when querying A but only wildcard AAAA exists", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeAAAA), Class: nbdns.DefaultClass, TTL: 300, RData: "2001:db8::1"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("anyhost.example.com.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA (success with no answer)") + assert.Len(t, resp.Answer, 0, "Should have no A answer") + }) + + t.Run("dual-stack wildcard returns both A and AAAA separately", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "*.example.com.", Type: int(dns.TypeAAAA), Class: nbdns.DefaultClass, TTL: 300, RData: "2001:db8::1"}, + }, + }}) + + // Query A + msgA := new(dns.Msg).SetQuestion("anyhost.example.com.", dns.TypeA) + var respA *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { respA = m; return nil }}, msgA) + + require.NotNil(t, respA) + require.Equal(t, dns.RcodeSuccess, respA.Rcode) + require.Len(t, respA.Answer, 1) + a, ok := respA.Answer[0].(*dns.A) + require.True(t, ok) + assert.Equal(t, "10.0.0.1", a.A.String()) + + // Query AAAA + msgAAAA := new(dns.Msg).SetQuestion("anyhost.example.com.", dns.TypeAAAA) + var respAAAA *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { respAAAA = m; return nil }}, msgAAAA) + + require.NotNil(t, respAAAA) + require.Equal(t, dns.RcodeSuccess, respAAAA.Rcode) + require.Len(t, respAAAA.Answer, 1) + aaaa, ok := respAAAA.Answer[0].(*dns.AAAA) + require.True(t, ok) + assert.Equal(t, "2001:db8::1", aaaa.AAAA.String()) + }) + + t.Run("specific A takes precedence over wildcard A", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "specific.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("specific.example.com.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.Len(t, resp.Answer, 1) + + a, ok := resp.Answer[0].(*dns.A) + require.True(t, ok) + assert.Equal(t, "192.168.1.1", a.A.String(), "Specific record should take precedence") + }) + + t.Run("specific AAAA takes precedence over wildcard AAAA", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeAAAA), Class: nbdns.DefaultClass, TTL: 300, RData: "2001:db8::1"}, + {Name: "specific.example.com.", Type: int(dns.TypeAAAA), Class: nbdns.DefaultClass, TTL: 300, RData: "2001:db8::2"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("specific.example.com.", dns.TypeAAAA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.Len(t, resp.Answer, 1) + + aaaa, ok := resp.Answer[0].(*dns.AAAA) + require.True(t, ok) + assert.Equal(t, "2001:db8::2", aaaa.AAAA.String(), "Specific record should take precedence") + }) + + t.Run("multiple wildcard A records round-robin", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "*.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "*.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("anyhost.example.com.", dns.TypeA) + + var firstIPs []string + for i := 0; i < 3; i++ { + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Len(t, resp.Answer, 3, "Should return all 3 A records") + + a, ok := resp.Answer[0].(*dns.A) + require.True(t, ok) + firstIPs = append(firstIPs, a.A.String()) + + // Verify owner name is rewritten for all records + for _, ans := range resp.Answer { + assert.Equal(t, "anyhost.example.com.", ans.Header().Name) + } + } + + // Verify rotation happened + assert.NotEqual(t, firstIPs[0], firstIPs[1], "First record should rotate") + assert.NotEqual(t, firstIPs[1], firstIPs[2], "Second rotation should differ") + }) + + t.Run("wildcard A case insensitive", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("ANYHOST.EXAMPLE.COM.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.Len(t, resp.Answer, 1) + }) + + t.Run("wildcard does not match multi-level subdomain", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + }, + }}) + + // *.example.com should NOT match sub.domain.example.com (standard DNS behavior) + msg := new(dns.Msg).SetQuestion("sub.domain.example.com.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + // This depends on implementation - standard DNS wildcards only match single label + // Current implementation replaces first label with *, so it WOULD match + // This test documents the current behavior + }) + + t.Run("wildcard with existing domain but different type returns NODATA", func(t *testing.T) { + resolver := NewResolver() + + // Specific A record exists, but query for TXT on wildcard domain + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("test.example.com.", dns.TypeTXT) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA for existing wildcard domain with different type") + assert.Len(t, resp.Answer, 0) + }) + + t.Run("mixed specific and wildcard returns correct records", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "specific.example.com.", Type: int(dns.TypeAAAA), Class: nbdns.DefaultClass, TTL: 300, RData: "2001:db8::1"}, + }, + }}) + + // Query A for specific - should use wildcard + msgA := new(dns.Msg).SetQuestion("specific.example.com.", dns.TypeA) + var respA *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { respA = m; return nil }}, msgA) + + require.NotNil(t, respA) + // This could be NODATA since specific.example.com exists but has no A + // or could return wildcard A - depends on implementation + // The current behavior returns NODATA because specific domain exists + + // Query AAAA for specific - should use specific record + msgAAAA := new(dns.Msg).SetQuestion("specific.example.com.", dns.TypeAAAA) + var respAAAA *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { respAAAA = m; return nil }}, msgAAAA) + + require.NotNil(t, respAAAA) + require.Equal(t, dns.RcodeSuccess, respAAAA.Rcode) + require.Len(t, respAAAA.Answer, 1) + aaaa, ok := respAAAA.Answer[0].(*dns.AAAA) + require.True(t, ok) + assert.Equal(t, "2001:db8::1", aaaa.AAAA.String()) + }) +} + +// TestLocalResolver_WildcardEdgeCases tests edge cases for wildcard record handling +func TestLocalResolver_WildcardEdgeCases(t *testing.T) { + t.Run("wildcard does not match NS queries", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("foo.example.com.", dns.TypeNS) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + assert.Equal(t, dns.RcodeNameError, resp.Rcode, "NS queries should not match wildcards") + assert.Len(t, resp.Answer, 0) + }) + + t.Run("wildcard does not match SOA queries", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("foo.example.com.", dns.TypeSOA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + assert.Equal(t, dns.RcodeNameError, resp.Rcode, "SOA queries should not match wildcards") + assert.Len(t, resp.Answer, 0) + }) + + t.Run("apex wildcard query", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + }, + }}) + + // Query for *.example.com directly (the wildcard itself) + msg := new(dns.Msg).SetQuestion("*.example.com.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.Len(t, resp.Answer, 1) + + a, ok := resp.Answer[0].(*dns.A) + require.True(t, ok) + assert.Equal(t, "10.0.0.1", a.A.String()) + }) + + t.Run("wildcard TXT record", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeTXT), Class: nbdns.DefaultClass, TTL: 300, RData: "v=spf1 -all"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("mail.example.com.", dns.TypeTXT) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.Len(t, resp.Answer, 1) + + txt, ok := resp.Answer[0].(*dns.TXT) + require.True(t, ok) + assert.Equal(t, "mail.example.com.", txt.Hdr.Name, "TXT owner should be rewritten") + }) + + t.Run("wildcard MX record", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeMX), Class: nbdns.DefaultClass, TTL: 300, RData: "10 mail.example.com."}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("sub.example.com.", dns.TypeMX) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.Len(t, resp.Answer, 1) + + mx, ok := resp.Answer[0].(*dns.MX) + require.True(t, ok) + assert.Equal(t, "sub.example.com.", mx.Hdr.Name, "MX owner should be rewritten") + }) + + t.Run("non-authoritative zone with wildcard CNAME triggers fallthrough for unmatched names", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + NonAuthoritative: true, + Records: []nbdns.SimpleRecord{ + {Name: "*.sub.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."}, + }, + }}) + + // Query for name not matching the wildcard pattern + msg := new(dns.Msg).SetQuestion("other.example.com.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + assert.True(t, resp.MsgHdr.Zero, "Should trigger fallthrough for non-authoritative zone") + }) +} + +// TestLocalResolver_MixedRecordTypes tests scenarios with A, AAAA, and CNAME records combined +func TestLocalResolver_MixedRecordTypes(t *testing.T) { + t.Run("specific A with wildcard CNAME - A query uses specific A", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."}, + {Name: "specific.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "target.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("specific.example.com.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.Len(t, resp.Answer, 1, "Should return only the specific A record") + + a, ok := resp.Answer[0].(*dns.A) + require.True(t, ok) + assert.Equal(t, "10.0.0.1", a.A.String(), "Should use specific A, not follow wildcard CNAME") + }) + + t.Run("specific AAAA with wildcard CNAME - AAAA query uses specific AAAA", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."}, + {Name: "specific.example.com.", Type: int(dns.TypeAAAA), Class: nbdns.DefaultClass, TTL: 300, RData: "2001:db8::1"}, + {Name: "target.example.com.", Type: int(dns.TypeAAAA), Class: nbdns.DefaultClass, TTL: 300, RData: "2001:db8::2"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("specific.example.com.", dns.TypeAAAA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.Len(t, resp.Answer, 1, "Should return only the specific AAAA record") + + aaaa, ok := resp.Answer[0].(*dns.AAAA) + require.True(t, ok) + assert.Equal(t, "2001:db8::1", aaaa.AAAA.String(), "Should use specific AAAA, not follow wildcard CNAME") + }) + + t.Run("specific A only - AAAA query returns NODATA", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeAAAA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA (success with no AAAA)") + assert.Len(t, resp.Answer, 0) + }) + + t.Run("specific AAAA only - A query returns NODATA", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "host.example.com.", Type: int(dns.TypeAAAA), Class: nbdns.DefaultClass, TTL: 300, RData: "2001:db8::1"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA (success with no A)") + assert.Len(t, resp.Answer, 0) + }) + + t.Run("CNAME with both A and AAAA target - A query returns CNAME + A", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "alias.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."}, + {Name: "target.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "target.example.com.", Type: int(dns.TypeAAAA), Class: nbdns.DefaultClass, TTL: 300, RData: "2001:db8::1"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("alias.example.com.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.Len(t, resp.Answer, 2, "Should have CNAME + A") + + cname, ok := resp.Answer[0].(*dns.CNAME) + require.True(t, ok) + assert.Equal(t, "target.example.com.", cname.Target) + + a, ok := resp.Answer[1].(*dns.A) + require.True(t, ok) + assert.Equal(t, "10.0.0.1", a.A.String()) + }) + + t.Run("CNAME with both A and AAAA target - AAAA query returns CNAME + AAAA", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "alias.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."}, + {Name: "target.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "target.example.com.", Type: int(dns.TypeAAAA), Class: nbdns.DefaultClass, TTL: 300, RData: "2001:db8::1"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("alias.example.com.", dns.TypeAAAA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.Len(t, resp.Answer, 2, "Should have CNAME + AAAA") + + cname, ok := resp.Answer[0].(*dns.CNAME) + require.True(t, ok) + assert.Equal(t, "target.example.com.", cname.Target) + + aaaa, ok := resp.Answer[1].(*dns.AAAA) + require.True(t, ok) + assert.Equal(t, "2001:db8::1", aaaa.AAAA.String()) + }) + + t.Run("CNAME to target with only A - AAAA query returns CNAME only (NODATA)", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "alias.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."}, + {Name: "target.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("alias.example.com.", dns.TypeAAAA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA with CNAME") + require.Len(t, resp.Answer, 1, "Should have only CNAME") + + _, ok := resp.Answer[0].(*dns.CNAME) + require.True(t, ok) + }) + + t.Run("CNAME to target with only AAAA - A query returns CNAME only (NODATA)", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "alias.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."}, + {Name: "target.example.com.", Type: int(dns.TypeAAAA), Class: nbdns.DefaultClass, TTL: 300, RData: "2001:db8::1"}, + }, + }}) + + msg := new(dns.Msg).SetQuestion("alias.example.com.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA with CNAME") + require.Len(t, resp.Answer, 1, "Should have only CNAME") + + _, ok := resp.Answer[0].(*dns.CNAME) + require.True(t, ok) + }) + + t.Run("wildcard A + wildcard AAAA + wildcard CNAME - each query type returns correct record", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "*.example.com.", Type: int(dns.TypeAAAA), Class: nbdns.DefaultClass, TTL: 300, RData: "2001:db8::1"}, + {Name: "*.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."}, + }, + }}) + + // A query should return wildcard A (not CNAME) + msgA := new(dns.Msg).SetQuestion("any.example.com.", dns.TypeA) + var respA *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { respA = m; return nil }}, msgA) + + require.NotNil(t, respA) + require.Equal(t, dns.RcodeSuccess, respA.Rcode) + require.Len(t, respA.Answer, 1) + a, ok := respA.Answer[0].(*dns.A) + require.True(t, ok, "A query should return A record, not CNAME") + assert.Equal(t, "10.0.0.1", a.A.String()) + + // AAAA query should return wildcard AAAA (not CNAME) + msgAAAA := new(dns.Msg).SetQuestion("any.example.com.", dns.TypeAAAA) + var respAAAA *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { respAAAA = m; return nil }}, msgAAAA) + + require.NotNil(t, respAAAA) + require.Equal(t, dns.RcodeSuccess, respAAAA.Rcode) + require.Len(t, respAAAA.Answer, 1) + aaaa, ok := respAAAA.Answer[0].(*dns.AAAA) + require.True(t, ok, "AAAA query should return AAAA record, not CNAME") + assert.Equal(t, "2001:db8::1", aaaa.AAAA.String()) + + // CNAME query should return wildcard CNAME + msgCNAME := new(dns.Msg).SetQuestion("any.example.com.", dns.TypeCNAME) + var respCNAME *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { respCNAME = m; return nil }}, msgCNAME) + + require.NotNil(t, respCNAME) + require.Equal(t, dns.RcodeSuccess, respCNAME.Rcode) + require.Len(t, respCNAME.Answer, 1) + cname, ok := respCNAME.Answer[0].(*dns.CNAME) + require.True(t, ok, "CNAME query should return CNAME record") + assert.Equal(t, "target.example.com.", cname.Target) + }) + + t.Run("dual-stack host with both A and AAAA", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "host.example.com.", Type: int(dns.TypeAAAA), Class: nbdns.DefaultClass, TTL: 300, RData: "2001:db8::1"}, + {Name: "host.example.com.", Type: int(dns.TypeAAAA), Class: nbdns.DefaultClass, TTL: 300, RData: "2001:db8::2"}, + }, + }}) + + // A query + msgA := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeA) + var respA *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { respA = m; return nil }}, msgA) + + require.NotNil(t, respA) + require.Equal(t, dns.RcodeSuccess, respA.Rcode) + require.Len(t, respA.Answer, 2, "Should return both A records") + for _, ans := range respA.Answer { + _, ok := ans.(*dns.A) + require.True(t, ok) + } + + // AAAA query + msgAAAA := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeAAAA) + var respAAAA *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { respAAAA = m; return nil }}, msgAAAA) + + require.NotNil(t, respAAAA) + require.Equal(t, dns.RcodeSuccess, respAAAA.Rcode) + require.Len(t, respAAAA.Answer, 2, "Should return both AAAA records") + for _, ans := range respAAAA.Answer { + _, ok := ans.(*dns.AAAA) + require.True(t, ok) + } + }) + + t.Run("CNAME chain with mixed record types at target", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "alias1.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "alias2.example.com."}, + {Name: "alias2.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."}, + {Name: "target.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "target.example.com.", Type: int(dns.TypeAAAA), Class: nbdns.DefaultClass, TTL: 300, RData: "2001:db8::1"}, + }, + }}) + + // A query through chain + msgA := new(dns.Msg).SetQuestion("alias1.example.com.", dns.TypeA) + var respA *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { respA = m; return nil }}, msgA) + + require.NotNil(t, respA) + require.Equal(t, dns.RcodeSuccess, respA.Rcode) + require.Len(t, respA.Answer, 3, "Should have 2 CNAMEs + 1 A") + + // Verify chain order + cname1, ok := respA.Answer[0].(*dns.CNAME) + require.True(t, ok) + assert.Equal(t, "alias2.example.com.", cname1.Target) + + cname2, ok := respA.Answer[1].(*dns.CNAME) + require.True(t, ok) + assert.Equal(t, "target.example.com.", cname2.Target) + + a, ok := respA.Answer[2].(*dns.A) + require.True(t, ok) + assert.Equal(t, "10.0.0.1", a.A.String()) + + // AAAA query through chain + msgAAAA := new(dns.Msg).SetQuestion("alias1.example.com.", dns.TypeAAAA) + var respAAAA *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { respAAAA = m; return nil }}, msgAAAA) + + require.NotNil(t, respAAAA) + require.Equal(t, dns.RcodeSuccess, respAAAA.Rcode) + require.Len(t, respAAAA.Answer, 3, "Should have 2 CNAMEs + 1 AAAA") + + aaaa, ok := respAAAA.Answer[2].(*dns.AAAA) + require.True(t, ok) + assert.Equal(t, "2001:db8::1", aaaa.AAAA.String()) + }) + + t.Run("wildcard CNAME with dual-stack target", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "*.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."}, + {Name: "target.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "target.example.com.", Type: int(dns.TypeAAAA), Class: nbdns.DefaultClass, TTL: 300, RData: "2001:db8::1"}, + }, + }}) + + // A query via wildcard CNAME + msgA := new(dns.Msg).SetQuestion("any.example.com.", dns.TypeA) + var respA *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { respA = m; return nil }}, msgA) + + require.NotNil(t, respA) + require.Equal(t, dns.RcodeSuccess, respA.Rcode) + require.Len(t, respA.Answer, 2, "Should have CNAME + A") + + cname, ok := respA.Answer[0].(*dns.CNAME) + require.True(t, ok) + assert.Equal(t, "any.example.com.", cname.Hdr.Name, "CNAME owner should be rewritten") + + a, ok := respA.Answer[1].(*dns.A) + require.True(t, ok) + assert.Equal(t, "10.0.0.1", a.A.String()) + + // AAAA query via wildcard CNAME + msgAAAA := new(dns.Msg).SetQuestion("other.example.com.", dns.TypeAAAA) + var respAAAA *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { respAAAA = m; return nil }}, msgAAAA) + + require.NotNil(t, respAAAA) + require.Equal(t, dns.RcodeSuccess, respAAAA.Rcode) + require.Len(t, respAAAA.Answer, 2, "Should have CNAME + AAAA") + + cname2, ok := respAAAA.Answer[0].(*dns.CNAME) + require.True(t, ok) + assert.Equal(t, "other.example.com.", cname2.Hdr.Name, "CNAME owner should be rewritten") + + aaaa, ok := respAAAA.Answer[1].(*dns.AAAA) + require.True(t, ok) + assert.Equal(t, "2001:db8::1", aaaa.AAAA.String()) + }) + + t.Run("specific A + wildcard AAAA - each query type returns correct record", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "*.example.com.", Type: int(dns.TypeAAAA), Class: nbdns.DefaultClass, TTL: 300, RData: "2001:db8::1"}, + }, + }}) + + // A query for host should return specific A + msgA := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeA) + var respA *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { respA = m; return nil }}, msgA) + + require.NotNil(t, respA) + require.Equal(t, dns.RcodeSuccess, respA.Rcode) + require.Len(t, respA.Answer, 1) + a, ok := respA.Answer[0].(*dns.A) + require.True(t, ok) + assert.Equal(t, "10.0.0.1", a.A.String()) + + // AAAA query for host should return NODATA (specific A exists, no AAAA for host.example.com) + msgAAAA := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeAAAA) + var respAAAA *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { respAAAA = m; return nil }}, msgAAAA) + + require.NotNil(t, respAAAA) + // host.example.com exists (has A), so AAAA query returns NODATA, not wildcard + assert.Equal(t, dns.RcodeSuccess, respAAAA.Rcode, "Should return NODATA for existing host without AAAA") + + // AAAA query for other host should return wildcard AAAA + msgAAAAOther := new(dns.Msg).SetQuestion("other.example.com.", dns.TypeAAAA) + var respAAAAOther *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { respAAAAOther = m; return nil }}, msgAAAAOther) + + require.NotNil(t, respAAAAOther) + require.Equal(t, dns.RcodeSuccess, respAAAAOther.Rcode) + require.Len(t, respAAAAOther.Answer, 1) + aaaa, ok := respAAAAOther.Answer[0].(*dns.AAAA) + require.True(t, ok) + assert.Equal(t, "2001:db8::1", aaaa.AAAA.String()) + assert.Equal(t, "other.example.com.", aaaa.Hdr.Name, "Owner should be rewritten") + }) + + t.Run("multiple zones with mixed records", func(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{ + { + Domain: "zone1.com.", + Records: []nbdns.SimpleRecord{ + {Name: "host.zone1.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.1.0.1"}, + {Name: "host.zone1.com.", Type: int(dns.TypeAAAA), Class: nbdns.DefaultClass, TTL: 300, RData: "2001:db8:1::1"}, + }, + }, + { + Domain: "zone2.com.", + Records: []nbdns.SimpleRecord{ + {Name: "alias.zone2.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.zone2.com."}, + {Name: "target.zone2.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.2.0.1"}, + }, + }, + }) + + // Query zone1 A + msg1A := new(dns.Msg).SetQuestion("host.zone1.com.", dns.TypeA) + var resp1A *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp1A = m; return nil }}, msg1A) + + require.NotNil(t, resp1A) + require.Equal(t, dns.RcodeSuccess, resp1A.Rcode) + require.Len(t, resp1A.Answer, 1) + + // Query zone1 AAAA + msg1AAAA := new(dns.Msg).SetQuestion("host.zone1.com.", dns.TypeAAAA) + var resp1AAAA *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp1AAAA = m; return nil }}, msg1AAAA) + + require.NotNil(t, resp1AAAA) + require.Equal(t, dns.RcodeSuccess, resp1AAAA.Rcode) + require.Len(t, resp1AAAA.Answer, 1) + + // Query zone2 via CNAME + msg2A := new(dns.Msg).SetQuestion("alias.zone2.com.", dns.TypeA) + var resp2A *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp2A = m; return nil }}, msg2A) + + require.NotNil(t, resp2A) + require.Equal(t, dns.RcodeSuccess, resp2A.Rcode) + require.Len(t, resp2A.Answer, 2, "Should have CNAME + A") + }) +} + // BenchmarkFindZone_BestCase benchmarks zone lookup with immediate match (first label) func BenchmarkFindZone_BestCase(b *testing.B) { resolver := NewResolver() diff --git a/client/internal/engine.go b/client/internal/engine.go index c5e2b7c6c..25a4e4048 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -505,6 +505,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) return fmt.Errorf("up wg interface: %w", err) } + // Set the WireGuard interface for rosenpass after interface is up + if e.rpManager != nil { + e.rpManager.SetInterface(e.wgInterface) + } + // if inbound conns are blocked there is no need to create the ACL manager if e.firewall != nil && !e.config.BlockInbound { e.acl = acl.NewDefaultManager(e.firewall) @@ -1512,6 +1517,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV if e.rpManager != nil { peerConn.SetOnConnected(e.rpManager.OnConnected) peerConn.SetOnDisconnected(e.rpManager.OnDisconnected) + peerConn.SetRosenpassInitializedPresharedKeyValidator(e.rpManager.IsPresharedKeyInitialized) } return peerConn, nil diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 56829393c..af9f27a71 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -214,6 +214,10 @@ func (m *MockWGIface) LastActivities() map[string]monotime.Time { return nil } +func (m *MockWGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error { + return nil +} + func TestMain(m *testing.M) { _ = util.InitLog("debug", util.LogConsole) code := m.Run() diff --git a/client/internal/iface_common.go b/client/internal/iface_common.go index 90b06cbd1..f8a433a6e 100644 --- a/client/internal/iface_common.go +++ b/client/internal/iface_common.go @@ -42,4 +42,5 @@ type wgIfaceBase interface { GetNet() *netstack.Net FullStats() (*configurer.Stats, error) LastActivities() map[string]monotime.Time + SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 80ca36789..ba82354a2 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -88,8 +88,9 @@ type Conn struct { relayManager *relayClient.Manager srWatcher *guard.SRWatcher - onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) - onDisconnected func(remotePeer string) + onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) + onDisconnected func(remotePeer string) + rosenpassInitializedPresharedKeyValidator func(peerKey string) bool statusRelay *worker.AtomicWorkerStatus statusICE *worker.AtomicWorkerStatus @@ -289,6 +290,13 @@ func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) { conn.onDisconnected = handler } +// SetRosenpassInitializedPresharedKeyValidator sets a function to check if Rosenpass has taken over +// PSK management for a peer. When this returns true, presharedKey() returns nil +// to prevent UpdatePeer from overwriting the Rosenpass-managed PSK. +func (conn *Conn) SetRosenpassInitializedPresharedKeyValidator(handler func(peerKey string) bool) { + conn.rosenpassInitializedPresharedKeyValidator = handler +} + func (conn *Conn) OnRemoteOffer(offer OfferAnswer) { conn.dumpState.RemoteOffer() conn.Log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay) @@ -759,10 +767,24 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key { return conn.config.WgConfig.PreSharedKey } + // If Rosenpass has already set a PSK for this peer, return nil to prevent + // UpdatePeer from overwriting the Rosenpass-managed key. + if conn.rosenpassInitializedPresharedKeyValidator != nil && conn.rosenpassInitializedPresharedKeyValidator(conn.config.Key) { + return nil + } + + // Use NetBird PSK as the seed for Rosenpass. This same PSK is passed to + // Rosenpass as PeerConfig.PresharedKey, ensuring the derived post-quantum + // key is cryptographically bound to the original secret. + if conn.config.WgConfig.PreSharedKey != nil { + return conn.config.WgConfig.PreSharedKey + } + + // Fallback to deterministic key if no NetBird PSK is configured determKey, err := conn.rosenpassDetermKey() if err != nil { conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err) - return conn.config.WgConfig.PreSharedKey + return nil } return determKey diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index 6b47f95eb..32383b530 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -284,3 +284,27 @@ func TestConn_presharedKey(t *testing.T) { }) } } + +func TestConn_presharedKey_RosenpassManaged(t *testing.T) { + conn := Conn{ + config: ConnConfig{ + Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", + LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", + RosenpassConfig: RosenpassConfig{PubKey: []byte("dummykey")}, + }, + } + + // When Rosenpass has already initialized the PSK for this peer, + // presharedKey must return nil to avoid UpdatePeer overwriting it. + conn.rosenpassInitializedPresharedKeyValidator = func(peerKey string) bool { return true } + if k := conn.presharedKey([]byte("remote")); k != nil { + t.Fatalf("expected nil presharedKey when Rosenpass manages PSK, got %v", k) + } + + // When Rosenpass hasn't taken over yet, presharedKey should provide + // a non-nil initial key (deterministic or from NetBird PSK). + conn.rosenpassInitializedPresharedKeyValidator = func(peerKey string) bool { return false } + if k := conn.presharedKey([]byte("remote")); k == nil { + t.Fatalf("expected non-nil presharedKey before Rosenpass manages PSK") + } +} diff --git a/client/internal/rosenpass/manager.go b/client/internal/rosenpass/manager.go index d2d7408fd..26a1eef58 100644 --- a/client/internal/rosenpass/manager.go +++ b/client/internal/rosenpass/manager.go @@ -34,6 +34,7 @@ type Manager struct { server *rp.Server lock sync.Mutex port int + wgIface PresharedKeySetter } // NewManager creates a new Rosenpass manager @@ -109,7 +110,13 @@ func (m *Manager) generateConfig() (rp.Config, error) { cfg.SecretKey = m.ssk cfg.Peers = []rp.PeerConfig{} - m.rpWgHandler, _ = NewNetbirdHandler(m.preSharedKey, m.ifaceName) + + m.lock.Lock() + m.rpWgHandler = NewNetbirdHandler() + if m.wgIface != nil { + m.rpWgHandler.SetInterface(m.wgIface) + } + m.lock.Unlock() cfg.Handlers = []rp.Handler{m.rpWgHandler} @@ -172,6 +179,20 @@ func (m *Manager) Close() error { return nil } +// SetInterface sets the WireGuard interface for the rosenpass handler. +// This can be called before or after Run() - the interface will be stored +// and passed to the handler when it's created or updated immediately if +// already running. +func (m *Manager) SetInterface(iface PresharedKeySetter) { + m.lock.Lock() + defer m.lock.Unlock() + + m.wgIface = iface + if m.rpWgHandler != nil { + m.rpWgHandler.SetInterface(iface) + } +} + // OnConnected is a handler function that is triggered when a connection to a remote peer establishes func (m *Manager) OnConnected(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) { m.lock.Lock() @@ -192,6 +213,20 @@ func (m *Manager) OnConnected(remoteWireGuardKey string, remoteRosenpassPubKey [ } } +// IsPresharedKeyInitialized returns true if Rosenpass has completed a handshake +// and set a PSK for the given WireGuard peer. +func (m *Manager) IsPresharedKeyInitialized(wireGuardPubKey string) bool { + m.lock.Lock() + defer m.lock.Unlock() + + peerID, ok := m.rpPeerIDs[wireGuardPubKey] + if !ok || peerID == nil { + return false + } + + return m.rpWgHandler.IsPeerInitialized(*peerID) +} + func findRandomAvailableUDPPort() (int, error) { conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { diff --git a/client/internal/rosenpass/netbird_handler.go b/client/internal/rosenpass/netbird_handler.go index 345f95c01..9de2409ef 100644 --- a/client/internal/rosenpass/netbird_handler.go +++ b/client/internal/rosenpass/netbird_handler.go @@ -1,46 +1,50 @@ package rosenpass import ( - "fmt" - "log/slog" + "sync" rp "cunicu.li/go-rosenpass" log "github.com/sirupsen/logrus" - "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +// PresharedKeySetter is the interface for setting preshared keys on WireGuard peers. +// This minimal interface allows rosenpass to update PSKs without depending on the full WGIface. +type PresharedKeySetter interface { + SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error +} + type wireGuardPeer struct { Interface string PublicKey rp.Key } type NetbirdHandler struct { - ifaceName string - client *wgctrl.Client - peers map[rp.PeerID]wireGuardPeer - presharedKey [32]byte + mu sync.Mutex + iface PresharedKeySetter + peers map[rp.PeerID]wireGuardPeer + initializedPeers map[rp.PeerID]bool } -func NewNetbirdHandler(preSharedKey *[32]byte, wgIfaceName string) (hdlr *NetbirdHandler, err error) { - hdlr = &NetbirdHandler{ - ifaceName: wgIfaceName, - peers: map[rp.PeerID]wireGuardPeer{}, +func NewNetbirdHandler() *NetbirdHandler { + return &NetbirdHandler{ + peers: map[rp.PeerID]wireGuardPeer{}, + initializedPeers: map[rp.PeerID]bool{}, } +} - if preSharedKey != nil { - hdlr.presharedKey = *preSharedKey - } - - if hdlr.client, err = wgctrl.New(); err != nil { - return nil, fmt.Errorf("failed to creat WireGuard client: %w", err) - } - - return hdlr, nil +// SetInterface sets the WireGuard interface for the handler. +// This must be called after the WireGuard interface is created. +func (h *NetbirdHandler) SetInterface(iface PresharedKeySetter) { + h.mu.Lock() + defer h.mu.Unlock() + h.iface = iface } func (h *NetbirdHandler) AddPeer(pid rp.PeerID, intf string, pk rp.Key) { + h.mu.Lock() + defer h.mu.Unlock() h.peers[pid] = wireGuardPeer{ Interface: intf, PublicKey: pk, @@ -48,79 +52,61 @@ func (h *NetbirdHandler) AddPeer(pid rp.PeerID, intf string, pk rp.Key) { } func (h *NetbirdHandler) RemovePeer(pid rp.PeerID) { + h.mu.Lock() + defer h.mu.Unlock() delete(h.peers, pid) + delete(h.initializedPeers, pid) +} + +// IsPeerInitialized returns true if Rosenpass has completed a handshake +// and set a PSK for this peer. +func (h *NetbirdHandler) IsPeerInitialized(pid rp.PeerID) bool { + h.mu.Lock() + defer h.mu.Unlock() + return h.initializedPeers[pid] } func (h *NetbirdHandler) HandshakeCompleted(pid rp.PeerID, key rp.Key) { - log.Debug("Handshake complete") h.outputKey(rp.KeyOutputReasonStale, pid, key) } func (h *NetbirdHandler) HandshakeExpired(pid rp.PeerID) { key, _ := rp.GeneratePresharedKey() - log.Debug("Handshake expired") h.outputKey(rp.KeyOutputReasonStale, pid, key) } func (h *NetbirdHandler) outputKey(_ rp.KeyOutputReason, pid rp.PeerID, psk rp.Key) { + h.mu.Lock() + iface := h.iface wg, ok := h.peers[pid] + isInitialized := h.initializedPeers[pid] + h.mu.Unlock() + + if iface == nil { + log.Warn("rosenpass: interface not set, cannot update preshared key") + return + } + if !ok { return } - device, err := h.client.Device(h.ifaceName) - if err != nil { - log.Errorf("Failed to get WireGuard device: %v", err) + peerKey := wgtypes.Key(wg.PublicKey).String() + pskKey := wgtypes.Key(psk) + + // Use updateOnly=true for later rotations (peer already has Rosenpass PSK) + // Use updateOnly=false for first rotation (peer has original/empty PSK) + if err := iface.SetPresharedKey(peerKey, pskKey, isInitialized); err != nil { + log.Errorf("Failed to apply rosenpass key: %v", err) return } - config := []wgtypes.PeerConfig{ - { - UpdateOnly: true, - PublicKey: wgtypes.Key(wg.PublicKey), - PresharedKey: (*wgtypes.Key)(&psk), - }, - } - for _, peer := range device.Peers { - if peer.PublicKey == wgtypes.Key(wg.PublicKey) { - if publicKeyEmpty(peer.PresharedKey) || peer.PresharedKey == h.presharedKey { - log.Debugf("Restart wireguard connection to peer %s", peer.PublicKey) - config = []wgtypes.PeerConfig{ - { - PublicKey: wgtypes.Key(wg.PublicKey), - PresharedKey: (*wgtypes.Key)(&psk), - Endpoint: peer.Endpoint, - AllowedIPs: peer.AllowedIPs, - }, - } - err = h.client.ConfigureDevice(wg.Interface, wgtypes.Config{ - Peers: []wgtypes.PeerConfig{ - { - Remove: true, - PublicKey: wgtypes.Key(wg.PublicKey), - }, - }, - }) - if err != nil { - slog.Debug("Failed to remove peer") - return - } - } + // Mark peer as isInitialized after the successful first rotation + if !isInitialized { + h.mu.Lock() + if _, exists := h.peers[pid]; exists { + h.initializedPeers[pid] = true } - } - - if err = h.client.ConfigureDevice(wg.Interface, wgtypes.Config{ - Peers: config, - }); err != nil { - log.Errorf("Failed to apply rosenpass key: %v", err) + h.mu.Unlock() } } - -func publicKeyEmpty(key wgtypes.Key) bool { - for _, b := range key { - if b != 0 { - return false - } - } - return true -} diff --git a/client/server/server.go b/client/server/server.go index 22e80ab25..408bd56db 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -66,7 +66,7 @@ type Server struct { proto.UnimplementedDaemonServiceServer clientRunning bool // protected by mutex clientRunningChan chan struct{} - clientGiveUpChan chan struct{} + clientGiveUpChan chan struct{} // closed when connectWithRetryRuns goroutine exits connectClient *internal.ConnectClient @@ -792,9 +792,11 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi // Down engine work in the daemon. func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) { s.mutex.Lock() - defer s.mutex.Unlock() + + giveUpChan := s.clientGiveUpChan if err := s.cleanupConnection(); err != nil { + s.mutex.Unlock() // todo review to update the status in case any type of error log.Errorf("failed to shut down properly: %v", err) return nil, err @@ -803,6 +805,20 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes state := internal.CtxGetState(s.rootCtx) state.Set(internal.StatusIdle) + s.mutex.Unlock() + + // Wait for the connectWithRetryRuns goroutine to finish with a short timeout. + // This prevents the goroutine from setting ErrResetConnection after Down() returns. + // The giveUpChan is closed at the end of connectWithRetryRuns. + if giveUpChan != nil { + select { + case <-giveUpChan: + log.Debugf("client goroutine finished successfully") + case <-time.After(5 * time.Second): + log.Warnf("timeout waiting for client goroutine to finish, proceeding anyway") + } + } + return &proto.DownResponse{}, nil } diff --git a/idp/dex/connector.go b/idp/dex/connector.go new file mode 100644 index 000000000..cad682141 --- /dev/null +++ b/idp/dex/connector.go @@ -0,0 +1,356 @@ +// Package dex provides an embedded Dex OIDC identity provider. +package dex + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + + "github.com/dexidp/dex/storage" +) + +// ConnectorConfig represents the configuration for an identity provider connector +type ConnectorConfig struct { + // ID is the unique identifier for the connector + ID string + // Name is a human-readable name for the connector + Name string + // Type is the connector type (oidc, google, microsoft) + Type string + // Issuer is the OIDC issuer URL (for OIDC-based connectors) + Issuer string + // ClientID is the OAuth2 client ID + ClientID string + // ClientSecret is the OAuth2 client secret + ClientSecret string + // RedirectURI is the OAuth2 redirect URI + RedirectURI string +} + +// CreateConnector creates a new connector in Dex storage. +// It maps the connector config to the appropriate Dex connector type and configuration. +func (p *Provider) CreateConnector(ctx context.Context, cfg *ConnectorConfig) (*ConnectorConfig, error) { + // Fill in the redirect URI if not provided + if cfg.RedirectURI == "" { + cfg.RedirectURI = p.GetRedirectURI() + } + + storageConn, err := p.buildStorageConnector(cfg) + if err != nil { + return nil, fmt.Errorf("failed to build connector: %w", err) + } + + if err := p.storage.CreateConnector(ctx, storageConn); err != nil { + return nil, fmt.Errorf("failed to create connector: %w", err) + } + + p.logger.Info("connector created", "id", cfg.ID, "type", cfg.Type) + return cfg, nil +} + +// GetConnector retrieves a connector by ID from Dex storage. +func (p *Provider) GetConnector(ctx context.Context, id string) (*ConnectorConfig, error) { + conn, err := p.storage.GetConnector(ctx, id) + if err != nil { + if err == storage.ErrNotFound { + return nil, err + } + return nil, fmt.Errorf("failed to get connector: %w", err) + } + + return p.parseStorageConnector(conn) +} + +// ListConnectors returns all connectors from Dex storage (excluding the local connector). +func (p *Provider) ListConnectors(ctx context.Context) ([]*ConnectorConfig, error) { + connectors, err := p.storage.ListConnectors(ctx) + if err != nil { + return nil, fmt.Errorf("failed to list connectors: %w", err) + } + + result := make([]*ConnectorConfig, 0, len(connectors)) + for _, conn := range connectors { + // Skip the local password connector + if conn.ID == "local" && conn.Type == "local" { + continue + } + + cfg, err := p.parseStorageConnector(conn) + if err != nil { + p.logger.Warn("failed to parse connector", "id", conn.ID, "error", err) + continue + } + result = append(result, cfg) + } + + return result, nil +} + +// UpdateConnector updates an existing connector in Dex storage. +// It merges incoming updates with existing values to prevent data loss on partial updates. +func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) error { + if err := p.storage.UpdateConnector(ctx, cfg.ID, func(old storage.Connector) (storage.Connector, error) { + oldCfg, err := p.parseStorageConnector(old) + if err != nil { + return storage.Connector{}, fmt.Errorf("failed to parse existing connector: %w", err) + } + + mergeConnectorConfig(cfg, oldCfg) + + storageConn, err := p.buildStorageConnector(cfg) + if err != nil { + return storage.Connector{}, fmt.Errorf("failed to build connector: %w", err) + } + return storageConn, nil + }); err != nil { + return fmt.Errorf("failed to update connector: %w", err) + } + + p.logger.Info("connector updated", "id", cfg.ID, "type", cfg.Type) + return nil +} + +// mergeConnectorConfig preserves existing values for empty fields in the update. +func mergeConnectorConfig(cfg, oldCfg *ConnectorConfig) { + if cfg.ClientSecret == "" { + cfg.ClientSecret = oldCfg.ClientSecret + } + if cfg.RedirectURI == "" { + cfg.RedirectURI = oldCfg.RedirectURI + } + if cfg.Issuer == "" && cfg.Type == oldCfg.Type { + cfg.Issuer = oldCfg.Issuer + } + if cfg.ClientID == "" { + cfg.ClientID = oldCfg.ClientID + } + if cfg.Name == "" { + cfg.Name = oldCfg.Name + } +} + +// DeleteConnector removes a connector from Dex storage. +func (p *Provider) DeleteConnector(ctx context.Context, id string) error { + // Prevent deletion of the local connector + if id == "local" { + return fmt.Errorf("cannot delete the local password connector") + } + + if err := p.storage.DeleteConnector(ctx, id); err != nil { + return fmt.Errorf("failed to delete connector: %w", err) + } + + p.logger.Info("connector deleted", "id", id) + return nil +} + +// GetRedirectURI returns the default redirect URI for connectors. +func (p *Provider) GetRedirectURI() string { + if p.config == nil { + return "" + } + issuer := strings.TrimSuffix(p.config.Issuer, "/") + if !strings.HasSuffix(issuer, "/oauth2") { + issuer += "/oauth2" + } + return issuer + "/callback" +} + +// buildStorageConnector creates a storage.Connector from ConnectorConfig. +// It handles the type-specific configuration for each connector type. +func (p *Provider) buildStorageConnector(cfg *ConnectorConfig) (storage.Connector, error) { + redirectURI := p.resolveRedirectURI(cfg.RedirectURI) + + var dexType string + var configData []byte + var err error + + switch cfg.Type { + case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak": + dexType = "oidc" + configData, err = buildOIDCConnectorConfig(cfg, redirectURI) + case "google": + dexType = "google" + configData, err = buildOAuth2ConnectorConfig(cfg, redirectURI) + case "microsoft": + dexType = "microsoft" + configData, err = buildOAuth2ConnectorConfig(cfg, redirectURI) + default: + return storage.Connector{}, fmt.Errorf("unsupported connector type: %s", cfg.Type) + } + if err != nil { + return storage.Connector{}, err + } + + return storage.Connector{ID: cfg.ID, Type: dexType, Name: cfg.Name, Config: configData}, nil +} + +// resolveRedirectURI returns the redirect URI, using a default if not provided +func (p *Provider) resolveRedirectURI(redirectURI string) string { + if redirectURI != "" || p.config == nil { + return redirectURI + } + issuer := strings.TrimSuffix(p.config.Issuer, "/") + if !strings.HasSuffix(issuer, "/oauth2") { + issuer += "/oauth2" + } + return issuer + "/callback" +} + +// buildOIDCConnectorConfig creates config for OIDC-based connectors +func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, error) { + oidcConfig := map[string]interface{}{ + "issuer": cfg.Issuer, + "clientID": cfg.ClientID, + "clientSecret": cfg.ClientSecret, + "redirectURI": redirectURI, + "scopes": []string{"openid", "profile", "email"}, + "insecureEnableGroups": true, + //some providers don't return email verified, so we need to skip it if not present (e.g., Entra, Okta, Duo) + "insecureSkipEmailVerified": true, + } + switch cfg.Type { + case "zitadel": + oidcConfig["getUserInfo"] = true + case "entra": + oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"} + case "okta": + oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"} + case "pocketid": + oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"} + } + return encodeConnectorConfig(oidcConfig) +} + +// buildOAuth2ConnectorConfig creates config for OAuth2 connectors (google, microsoft) +func buildOAuth2ConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, error) { + return encodeConnectorConfig(map[string]interface{}{ + "clientID": cfg.ClientID, + "clientSecret": cfg.ClientSecret, + "redirectURI": redirectURI, + }) +} + +// parseStorageConnector converts a storage.Connector back to ConnectorConfig. +// It infers the original identity provider type from the Dex connector type and ID. +func (p *Provider) parseStorageConnector(conn storage.Connector) (*ConnectorConfig, error) { + cfg := &ConnectorConfig{ + ID: conn.ID, + Name: conn.Name, + } + + if len(conn.Config) == 0 { + cfg.Type = conn.Type + return cfg, nil + } + + var configMap map[string]interface{} + if err := decodeConnectorConfig(conn.Config, &configMap); err != nil { + return nil, fmt.Errorf("failed to parse connector config: %w", err) + } + + // Extract common fields + if v, ok := configMap["clientID"].(string); ok { + cfg.ClientID = v + } + if v, ok := configMap["clientSecret"].(string); ok { + cfg.ClientSecret = v + } + if v, ok := configMap["redirectURI"].(string); ok { + cfg.RedirectURI = v + } + if v, ok := configMap["issuer"].(string); ok { + cfg.Issuer = v + } + + // Infer the original identity provider type from Dex connector type and ID + cfg.Type = inferIdentityProviderType(conn.Type, conn.ID, configMap) + + return cfg, nil +} + +// inferIdentityProviderType determines the original identity provider type +// based on the Dex connector type, connector ID, and configuration. +func inferIdentityProviderType(dexType, connectorID string, _ map[string]interface{}) string { + if dexType != "oidc" { + return dexType + } + return inferOIDCProviderType(connectorID) +} + +// inferOIDCProviderType infers the specific OIDC provider from connector ID +func inferOIDCProviderType(connectorID string) string { + connectorIDLower := strings.ToLower(connectorID) + for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak"} { + if strings.Contains(connectorIDLower, provider) { + return provider + } + } + return "oidc" +} + +// encodeConnectorConfig serializes connector config to JSON bytes. +func encodeConnectorConfig(config map[string]interface{}) ([]byte, error) { + return json.Marshal(config) +} + +// decodeConnectorConfig deserializes connector config from JSON bytes. +func decodeConnectorConfig(data []byte, v interface{}) error { + return json.Unmarshal(data, v) +} + +// ensureLocalConnector creates a local (password) connector if it doesn't exist +func ensureLocalConnector(ctx context.Context, stor storage.Storage) error { + // Check specifically for the local connector + _, err := stor.GetConnector(ctx, "local") + if err == nil { + // Local connector already exists + return nil + } + if !errors.Is(err, storage.ErrNotFound) { + return fmt.Errorf("failed to get local connector: %w", err) + } + + // Create a local connector for password authentication + localConnector := storage.Connector{ + ID: "local", + Type: "local", + Name: "Email", + } + + if err := stor.CreateConnector(ctx, localConnector); err != nil { + return fmt.Errorf("failed to create local connector: %w", err) + } + + return nil +} + +// ensureStaticConnectors creates or updates static connectors in storage +func ensureStaticConnectors(ctx context.Context, stor storage.Storage, connectors []Connector) error { + for _, conn := range connectors { + storConn, err := conn.ToStorageConnector() + if err != nil { + return fmt.Errorf("failed to convert connector %s: %w", conn.ID, err) + } + _, err = stor.GetConnector(ctx, conn.ID) + if err == storage.ErrNotFound { + if err := stor.CreateConnector(ctx, storConn); err != nil { + return fmt.Errorf("failed to create connector %s: %w", conn.ID, err) + } + continue + } + if err != nil { + return fmt.Errorf("failed to get connector %s: %w", conn.ID, err) + } + if err := stor.UpdateConnector(ctx, conn.ID, func(old storage.Connector) (storage.Connector, error) { + old.Name = storConn.Name + old.Config = storConn.Config + return old, nil + }); err != nil { + return fmt.Errorf("failed to update connector %s: %w", conn.ID, err) + } + } + return nil +} diff --git a/idp/dex/provider.go b/idp/dex/provider.go index 6625d9eaf..6c608dbf5 100644 --- a/idp/dex/provider.go +++ b/idp/dex/provider.go @@ -4,7 +4,6 @@ package dex import ( "context" "encoding/base64" - "encoding/json" "errors" "fmt" "log/slog" @@ -245,34 +244,6 @@ func ensureStaticClients(ctx context.Context, stor storage.Storage, clients []st return nil } -// ensureStaticConnectors creates or updates static connectors in storage -func ensureStaticConnectors(ctx context.Context, stor storage.Storage, connectors []Connector) error { - for _, conn := range connectors { - storConn, err := conn.ToStorageConnector() - if err != nil { - return fmt.Errorf("failed to convert connector %s: %w", conn.ID, err) - } - _, err = stor.GetConnector(ctx, conn.ID) - if errors.Is(err, storage.ErrNotFound) { - if err := stor.CreateConnector(ctx, storConn); err != nil { - return fmt.Errorf("failed to create connector %s: %w", conn.ID, err) - } - continue - } - if err != nil { - return fmt.Errorf("failed to get connector %s: %w", conn.ID, err) - } - if err := stor.UpdateConnector(ctx, conn.ID, func(old storage.Connector) (storage.Connector, error) { - old.Name = storConn.Name - old.Config = storConn.Config - return old, nil - }); err != nil { - return fmt.Errorf("failed to update connector %s: %w", conn.ID, err) - } - } - return nil -} - // buildDexConfig creates a server.Config with defaults applied func buildDexConfig(yamlConfig *YAMLConfig, stor storage.Storage, logger *slog.Logger) server.Config { cfg := yamlConfig.ToServerConfig(stor, logger) @@ -613,294 +584,37 @@ func (p *Provider) ListUsers(ctx context.Context) ([]storage.Password, error) { return p.storage.ListPasswords(ctx) } -// ensureLocalConnector creates a local (password) connector if none exists -func ensureLocalConnector(ctx context.Context, stor storage.Storage) error { - connectors, err := stor.ListConnectors(ctx) +// UpdateUserPassword updates the password for a user identified by userID. +// The userID can be either an encoded Dex ID (base64 protobuf) or a raw UUID. +// It verifies the current password before updating. +func (p *Provider) UpdateUserPassword(ctx context.Context, userID string, oldPassword, newPassword string) error { + // Get the user by ID to find their email + user, err := p.GetUserByID(ctx, userID) if err != nil { - return fmt.Errorf("failed to list connectors: %w", err) + return fmt.Errorf("failed to get user: %w", err) } - // If any connector exists, we're good - if len(connectors) > 0 { - return nil + // Verify old password + if err := bcrypt.CompareHashAndPassword(user.Hash, []byte(oldPassword)); err != nil { + return fmt.Errorf("current password is incorrect") } - // Create a local connector for password authentication - localConnector := storage.Connector{ - ID: "local", - Type: "local", - Name: "Email", - } - - if err := stor.CreateConnector(ctx, localConnector); err != nil { - return fmt.Errorf("failed to create local connector: %w", err) - } - - return nil -} - -// ConnectorConfig represents the configuration for an identity provider connector -type ConnectorConfig struct { - // ID is the unique identifier for the connector - ID string - // Name is a human-readable name for the connector - Name string - // Type is the connector type (oidc, google, microsoft) - Type string - // Issuer is the OIDC issuer URL (for OIDC-based connectors) - Issuer string - // ClientID is the OAuth2 client ID - ClientID string - // ClientSecret is the OAuth2 client secret - ClientSecret string - // RedirectURI is the OAuth2 redirect URI - RedirectURI string -} - -// CreateConnector creates a new connector in Dex storage. -// It maps the connector config to the appropriate Dex connector type and configuration. -func (p *Provider) CreateConnector(ctx context.Context, cfg *ConnectorConfig) (*ConnectorConfig, error) { - // Fill in the redirect URI if not provided - if cfg.RedirectURI == "" { - cfg.RedirectURI = p.GetRedirectURI() - } - - storageConn, err := p.buildStorageConnector(cfg) + // Hash the new password + newHash, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost) if err != nil { - return nil, fmt.Errorf("failed to build connector: %w", err) + return fmt.Errorf("failed to hash new password: %w", err) } - if err := p.storage.CreateConnector(ctx, storageConn); err != nil { - return nil, fmt.Errorf("failed to create connector: %w", err) - } - - p.logger.Info("connector created", "id", cfg.ID, "type", cfg.Type) - return cfg, nil -} - -// GetConnector retrieves a connector by ID from Dex storage. -func (p *Provider) GetConnector(ctx context.Context, id string) (*ConnectorConfig, error) { - conn, err := p.storage.GetConnector(ctx, id) - if err != nil { - if err == storage.ErrNotFound { - return nil, err - } - return nil, fmt.Errorf("failed to get connector: %w", err) - } - - return p.parseStorageConnector(conn) -} - -// ListConnectors returns all connectors from Dex storage (excluding the local connector). -func (p *Provider) ListConnectors(ctx context.Context) ([]*ConnectorConfig, error) { - connectors, err := p.storage.ListConnectors(ctx) - if err != nil { - return nil, fmt.Errorf("failed to list connectors: %w", err) - } - - result := make([]*ConnectorConfig, 0, len(connectors)) - for _, conn := range connectors { - // Skip the local password connector - if conn.ID == "local" && conn.Type == "local" { - continue - } - - cfg, err := p.parseStorageConnector(conn) - if err != nil { - p.logger.Warn("failed to parse connector", "id", conn.ID, "error", err) - continue - } - result = append(result, cfg) - } - - return result, nil -} - -// UpdateConnector updates an existing connector in Dex storage. -func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) error { - storageConn, err := p.buildStorageConnector(cfg) - if err != nil { - return fmt.Errorf("failed to build connector: %w", err) - } - - if err := p.storage.UpdateConnector(ctx, cfg.ID, func(old storage.Connector) (storage.Connector, error) { - return storageConn, nil - }); err != nil { - return fmt.Errorf("failed to update connector: %w", err) - } - - p.logger.Info("connector updated", "id", cfg.ID, "type", cfg.Type) - return nil -} - -// DeleteConnector removes a connector from Dex storage. -func (p *Provider) DeleteConnector(ctx context.Context, id string) error { - // Prevent deletion of the local connector - if id == "local" { - return fmt.Errorf("cannot delete the local password connector") - } - - if err := p.storage.DeleteConnector(ctx, id); err != nil { - return fmt.Errorf("failed to delete connector: %w", err) - } - - p.logger.Info("connector deleted", "id", id) - return nil -} - -// buildStorageConnector creates a storage.Connector from ConnectorConfig. -// It handles the type-specific configuration for each connector type. -func (p *Provider) buildStorageConnector(cfg *ConnectorConfig) (storage.Connector, error) { - redirectURI := p.resolveRedirectURI(cfg.RedirectURI) - - var dexType string - var configData []byte - var err error - - switch cfg.Type { - case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak": - dexType = "oidc" - configData, err = buildOIDCConnectorConfig(cfg, redirectURI) - case "google": - dexType = "google" - configData, err = buildOAuth2ConnectorConfig(cfg, redirectURI) - case "microsoft": - dexType = "microsoft" - configData, err = buildOAuth2ConnectorConfig(cfg, redirectURI) - default: - return storage.Connector{}, fmt.Errorf("unsupported connector type: %s", cfg.Type) - } - if err != nil { - return storage.Connector{}, err - } - - return storage.Connector{ID: cfg.ID, Type: dexType, Name: cfg.Name, Config: configData}, nil -} - -// resolveRedirectURI returns the redirect URI, using a default if not provided -func (p *Provider) resolveRedirectURI(redirectURI string) string { - if redirectURI != "" || p.config == nil { - return redirectURI - } - issuer := strings.TrimSuffix(p.config.Issuer, "/") - if !strings.HasSuffix(issuer, "/oauth2") { - issuer += "/oauth2" - } - return issuer + "/callback" -} - -// buildOIDCConnectorConfig creates config for OIDC-based connectors -func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, error) { - oidcConfig := map[string]interface{}{ - "issuer": cfg.Issuer, - "clientID": cfg.ClientID, - "clientSecret": cfg.ClientSecret, - "redirectURI": redirectURI, - "scopes": []string{"openid", "profile", "email"}, - "insecureEnableGroups": true, - //some providers don't return email verified, so we need to skip it if not present (e.g., Entra, Okta, Duo) - "insecureSkipEmailVerified": true, - } - switch cfg.Type { - case "zitadel": - oidcConfig["getUserInfo"] = true - case "entra": - oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"} - case "okta": - oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"} - case "pocketid": - oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"} - } - return encodeConnectorConfig(oidcConfig) -} - -// buildOAuth2ConnectorConfig creates config for OAuth2 connectors (google, microsoft) -func buildOAuth2ConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, error) { - return encodeConnectorConfig(map[string]interface{}{ - "clientID": cfg.ClientID, - "clientSecret": cfg.ClientSecret, - "redirectURI": redirectURI, + // Update the password in storage + err = p.storage.UpdatePassword(ctx, user.Email, func(old storage.Password) (storage.Password, error) { + old.Hash = newHash + return old, nil }) -} - -// parseStorageConnector converts a storage.Connector back to ConnectorConfig. -// It infers the original identity provider type from the Dex connector type and ID. -func (p *Provider) parseStorageConnector(conn storage.Connector) (*ConnectorConfig, error) { - cfg := &ConnectorConfig{ - ID: conn.ID, - Name: conn.Name, + if err != nil { + return fmt.Errorf("failed to update password: %w", err) } - if len(conn.Config) == 0 { - cfg.Type = conn.Type - return cfg, nil - } - - var configMap map[string]interface{} - if err := decodeConnectorConfig(conn.Config, &configMap); err != nil { - return nil, fmt.Errorf("failed to parse connector config: %w", err) - } - - // Extract common fields - if v, ok := configMap["clientID"].(string); ok { - cfg.ClientID = v - } - if v, ok := configMap["clientSecret"].(string); ok { - cfg.ClientSecret = v - } - if v, ok := configMap["redirectURI"].(string); ok { - cfg.RedirectURI = v - } - if v, ok := configMap["issuer"].(string); ok { - cfg.Issuer = v - } - - // Infer the original identity provider type from Dex connector type and ID - cfg.Type = inferIdentityProviderType(conn.Type, conn.ID, configMap) - - return cfg, nil -} - -// inferIdentityProviderType determines the original identity provider type -// based on the Dex connector type, connector ID, and configuration. -func inferIdentityProviderType(dexType, connectorID string, _ map[string]interface{}) string { - if dexType != "oidc" { - return dexType - } - return inferOIDCProviderType(connectorID) -} - -// inferOIDCProviderType infers the specific OIDC provider from connector ID -func inferOIDCProviderType(connectorID string) string { - connectorIDLower := strings.ToLower(connectorID) - for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak"} { - if strings.Contains(connectorIDLower, provider) { - return provider - } - } - return "oidc" -} - -// encodeConnectorConfig serializes connector config to JSON bytes. -func encodeConnectorConfig(config map[string]interface{}) ([]byte, error) { - return json.Marshal(config) -} - -// decodeConnectorConfig deserializes connector config from JSON bytes. -func decodeConnectorConfig(data []byte, v interface{}) error { - return json.Unmarshal(data, v) -} - -// GetRedirectURI returns the default redirect URI for connectors. -func (p *Provider) GetRedirectURI() string { - if p.config == nil { - return "" - } - issuer := strings.TrimSuffix(p.config.Issuer, "/") - if !strings.HasSuffix(issuer, "/oauth2") { - issuer += "/oauth2" - } - return issuer + "/callback" + return nil } // GetIssuer returns the OIDC issuer URL. diff --git a/infrastructure_files/getting-started.sh b/infrastructure_files/getting-started.sh index 8676840a6..25599997c 100755 --- a/infrastructure_files/getting-started.sh +++ b/infrastructure_files/getting-started.sh @@ -82,16 +82,6 @@ read_nb_domain() { return 0 } -get_turn_external_ip() { - TURN_EXTERNAL_IP_CONFIG="#external-ip=" - IP=$(curl -s -4 https://jsonip.com | jq -r '.ip') - if [[ "x-$IP" != "x-" ]]; then - TURN_EXTERNAL_IP_CONFIG="external-ip=$IP" - fi - echo "$TURN_EXTERNAL_IP_CONFIG" - return 0 -} - read_reverse_proxy_type() { echo "" > /dev/stderr echo "Which reverse proxy will you use?" > /dev/stderr @@ -249,14 +239,17 @@ initialize_default_values() { NETBIRD_PORT=80 NETBIRD_HTTP_PROTOCOL="http" NETBIRD_RELAY_PROTO="rel" - TURN_USER="self" - TURN_PASSWORD=$(openssl rand -base64 32 | sed "$SED_STRIP_PADDING") NETBIRD_RELAY_AUTH_SECRET=$(openssl rand -base64 32 | sed "$SED_STRIP_PADDING") # Note: DataStoreEncryptionKey must keep base64 padding (=) for Go's base64.StdEncoding DATASTORE_ENCRYPTION_KEY=$(openssl rand -base64 32) - TURN_MIN_PORT=49152 - TURN_MAX_PORT=65535 - TURN_EXTERNAL_IP_CONFIG=$(get_turn_external_ip) + NETBIRD_STUN_PORT=3478 + + # Docker images + CADDY_IMAGE="caddy" + DASHBOARD_IMAGE="netbirdio/dashboard:latest" + SIGNAL_IMAGE="netbirdio/signal:latest" + RELAY_IMAGE="netbirdio/relay:latest" + MANAGEMENT_IMAGE="netbirdio/management:latest" # Reverse proxy configuration REVERSE_PROXY_TYPE="0" @@ -320,7 +313,7 @@ check_existing_installation() { echo "Generated files already exist, if you want to reinitialize the environment, please remove them first." echo "You can use the following commands:" echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes" - echo " rm -f docker-compose.yml Caddyfile dashboard.env turnserver.conf management.json relay.env nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt" + echo " rm -f docker-compose.yml Caddyfile dashboard.env management.json relay.env nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt" echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard." exit 1 fi @@ -363,7 +356,6 @@ generate_configuration_files() { # Common files for all configurations render_dashboard_env > dashboard.env render_management_json > management.json - render_turn_server_conf > turnserver.conf render_relay_env > relay.env return 0 } @@ -487,34 +479,13 @@ EOF return 0 } -render_turn_server_conf() { - cat <