From 4a25a0d413f10c6cfe75e3b65f6a6dafdd8e0c56 Mon Sep 17 00:00:00 2001 From: Owen Date: Sat, 31 Jan 2026 16:58:05 -0800 Subject: [PATCH 1/4] Dont go unregistered when low power mode Former-commit-id: 0938564038c6e50f9b0bba166f79f9f2ab3a366a --- olm/connect.go | 4 ++-- olm/data.go | 2 +- olm/olm.go | 11 ++++------- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/olm/connect.go b/olm/connect.go index 90ad567..dc05d1f 100644 --- a/olm/connect.go +++ b/olm/connect.go @@ -36,7 +36,7 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) { var wgData WgData - if o.connected { + if o.registered { logger.Info("Already connected. Ignoring new connection request.") return } @@ -208,7 +208,7 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) { o.apiServer.SetRegistered(true) - o.connected = true + o.registered = true // Start ping monitor now that we are registered and connected o.websocket.StartPingMonitor() diff --git a/olm/data.go b/olm/data.go index 050a23f..8bd0997 100644 --- a/olm/data.go +++ b/olm/data.go @@ -157,7 +157,7 @@ func (o *Olm) handleWgPeerUpdateData(msg websocket.WSMessage) { func (o *Olm) handleSync(msg websocket.WSMessage) { logger.Debug("Received sync message: %v", msg.Data) - if !o.connected { + if !o.registered { logger.Warn("Not connected, ignoring sync request") return } diff --git a/olm/olm.go b/olm/olm.go index e3a9d77..0625a63 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -32,7 +32,7 @@ type Olm struct { privateKey wgtypes.Key logFile *os.File - connected bool + registered bool tunnelRunning bool uapiListener net.Listener @@ -386,10 +386,10 @@ func (o *Olm) StartTunnel(config TunnelConfig) { o.apiServer.SetConnectionStatus(true) - if o.connected { + if o.registered { o.websocket.StartPingMonitor() - logger.Debug("Already connected, skipping registration") + logger.Debug("Already registered, skipping registration") return nil } @@ -615,7 +615,7 @@ func (o *Olm) StopTunnel() error { } // Reset the running state BEFORE cleanup to prevent callbacks from accessing nil pointers - o.connected = false + o.registered = false o.tunnelRunning = false // Cancel the tunnel context if it exists @@ -739,9 +739,6 @@ func (o *Olm) SetPowerMode(mode string) error { logger.Info("Switching to low power mode") - // Mark as disconnected so we re-register on reconnect - o.connected = false - // Update API server connection status if o.apiServer != nil { o.apiServer.SetConnectionStatus(false) From 1be5e454baff3185202d5bf31326a8ac13353688 Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 2 Feb 2026 10:03:22 -0800 Subject: [PATCH 2/4] Default override dns to true Ref #59 --- config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.go b/config.go index 910dc4e..5959270 100644 --- a/config.go +++ b/config.go @@ -89,7 +89,7 @@ func DefaultConfig() *OlmConfig { PingInterval: "3s", PingTimeout: "5s", DisableHolepunch: false, - OverrideDNS: false, + OverrideDNS: true, TunnelDNS: false, // DoNotCreateNewClient: false, sources: make(map[string]string), From dd9bff9a4b0d845b3e69f756ea50ae9ce44c712a Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 2 Feb 2026 18:03:29 -0800 Subject: [PATCH 3/4] Fix peer names clearing --- api/api.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/api/api.go b/api/api.go index efd3346..047ce08 100644 --- a/api/api.go +++ b/api/api.go @@ -272,9 +272,6 @@ func (s *API) SetConnectionStatus(isConnected bool) { if isConnected { s.connectedAt = time.Now() - } else { - // Clear peer statuses when disconnected - s.peerStatuses = make(map[int]*PeerStatus) } } From af973b244064ea62aded51f9c7b60919b6f8c8d2 Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 6 Feb 2026 15:17:01 -0800 Subject: [PATCH 4/4] Support prt records --- dns/dns_proxy.go | 30 ++- dns/dns_records.go | 174 +++++++++++++++- dns/dns_records_test.go | 449 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 651 insertions(+), 2 deletions(-) diff --git a/dns/dns_proxy.go b/dns/dns_proxy.go index f65e923..986e847 100644 --- a/dns/dns_proxy.go +++ b/dns/dns_proxy.go @@ -380,7 +380,7 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie // Check if we have local records for this query var response *dns.Msg - if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA { + if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA || question.Qtype == dns.TypePTR { response = p.checkLocalRecords(msg, question) } @@ -410,6 +410,34 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie // checkLocalRecords checks if we have local records for the query func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns.Msg { + // Handle PTR queries + if question.Qtype == dns.TypePTR { + if ptrDomain, ok := p.recordStore.GetPTRRecord(question.Name); ok { + logger.Debug("Found local PTR record for %s -> %s", question.Name, ptrDomain) + + // Create response message + response := new(dns.Msg) + response.SetReply(query) + response.Authoritative = true + + // Add PTR answer record + rr := &dns.PTR{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + Ttl: 300, // 5 minutes + }, + Ptr: ptrDomain, + } + response.Answer = append(response.Answer, rr) + + return response + } + return nil + } + + // Handle A and AAAA queries var recordType RecordType if question.Qtype == dns.TypeA { recordType = RecordTypeA diff --git a/dns/dns_records.go b/dns/dns_records.go index cef0ad4..199b94b 100644 --- a/dns/dns_records.go +++ b/dns/dns_records.go @@ -1,6 +1,7 @@ package dns import ( + "fmt" "net" "strings" "sync" @@ -14,15 +15,17 @@ type RecordType uint16 const ( RecordTypeA RecordType = RecordType(dns.TypeA) RecordTypeAAAA RecordType = RecordType(dns.TypeAAAA) + RecordTypePTR RecordType = RecordType(dns.TypePTR) ) -// DNSRecordStore manages local DNS records for A and AAAA queries +// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries type DNSRecordStore struct { mu sync.RWMutex aRecords map[string][]net.IP // domain -> list of IPv4 addresses aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses aWildcards map[string][]net.IP // wildcard pattern -> list of IPv4 addresses aaaaWildcards map[string][]net.IP // wildcard pattern -> list of IPv6 addresses + ptrRecords map[string]string // IP address string -> domain name } // NewDNSRecordStore creates a new DNS record store @@ -32,6 +35,7 @@ func NewDNSRecordStore() *DNSRecordStore { aaaaRecords: make(map[string][]net.IP), aWildcards: make(map[string][]net.IP), aaaaWildcards: make(map[string][]net.IP), + ptrRecords: make(map[string]string), } } @@ -39,6 +43,7 @@ func NewDNSRecordStore() *DNSRecordStore { // domain should be in FQDN format (e.g., "example.com.") // domain can contain wildcards: * (0+ chars) and ? (exactly 1 char) // ip should be a valid IPv4 or IPv6 address +// Automatically adds a corresponding PTR record for non-wildcard domains func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error { s.mu.Lock() defer s.mu.Unlock() @@ -60,6 +65,8 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error { s.aWildcards[domain] = append(s.aWildcards[domain], ip) } else { s.aRecords[domain] = append(s.aRecords[domain], ip) + // Automatically add PTR record for non-wildcard domains + s.ptrRecords[ip.String()] = domain } } else if ip.To16() != nil { // IPv6 address @@ -67,6 +74,8 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error { s.aaaaWildcards[domain] = append(s.aaaaWildcards[domain], ip) } else { s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip) + // Automatically add PTR record for non-wildcard domains + s.ptrRecords[ip.String()] = domain } } else { return &net.ParseError{Type: "IP address", Text: ip.String()} @@ -75,8 +84,30 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error { return nil } +// AddPTRRecord adds a PTR record mapping an IP address to a domain name +// ip should be a valid IPv4 or IPv6 address +// domain should be in FQDN format (e.g., "example.com.") +func (s *DNSRecordStore) AddPTRRecord(ip net.IP, domain string) error { + s.mu.Lock() + defer s.mu.Unlock() + + // Ensure domain ends with a dot (FQDN format) + if len(domain) == 0 || domain[len(domain)-1] != '.' { + domain = domain + "." + } + + // Normalize domain to lowercase FQDN + domain = strings.ToLower(dns.Fqdn(domain)) + + // Store PTR record using IP string as key + s.ptrRecords[ip.String()] = domain + + return nil +} + // RemoveRecord removes a specific DNS record mapping // If ip is nil, removes all records for the domain (including wildcards) +// Automatically removes corresponding PTR records for non-wildcard domains func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) { s.mu.Lock() defer s.mu.Unlock() @@ -98,6 +129,23 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) { delete(s.aWildcards, domain) delete(s.aaaaWildcards, domain) } else { + // For non-wildcard domains, remove PTR records for all IPs + if ips, ok := s.aRecords[domain]; ok { + for _, ipAddr := range ips { + // Only remove PTR if it points to this domain + if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ipAddr.String()) + } + } + } + if ips, ok := s.aaaaRecords[domain]; ok { + for _, ipAddr := range ips { + // Only remove PTR if it points to this domain + if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ipAddr.String()) + } + } + } delete(s.aRecords, domain) delete(s.aaaaRecords, domain) } @@ -119,6 +167,10 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) { if len(s.aRecords[domain]) == 0 { delete(s.aRecords, domain) } + // Automatically remove PTR record if it points to this domain + if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ip.String()) + } } } } else if ip.To16() != nil { @@ -136,11 +188,23 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) { if len(s.aaaaRecords[domain]) == 0 { delete(s.aaaaRecords, domain) } + // Automatically remove PTR record if it points to this domain + if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ip.String()) + } } } } } +// RemovePTRRecord removes a PTR record for an IP address +func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) { + s.mu.Lock() + defer s.mu.Unlock() + + delete(s.ptrRecords, ip.String()) +} + // GetRecords returns all IP addresses for a domain and record type // First checks for exact matches, then checks wildcard patterns func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP { @@ -198,6 +262,26 @@ func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net. return records } +// GetPTRRecord returns the domain name for a PTR record query +// domain should be in reverse DNS format (e.g., "1.0.0.127.in-addr.arpa.") +func (s *DNSRecordStore) GetPTRRecord(domain string) (string, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + + // Convert reverse DNS format to IP address + ip := reverseDNSToIP(domain) + if ip == nil { + return "", false + } + + // Look up the PTR record + if ptrDomain, ok := s.ptrRecords[ip.String()]; ok { + return ptrDomain, true + } + + return "", false +} + // HasRecord checks if a domain has any records of the specified type // Checks both exact matches and wildcard patterns func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool { @@ -235,6 +319,21 @@ func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool { return false } +// HasPTRRecord checks if a PTR record exists for the given reverse DNS domain +func (s *DNSRecordStore) HasPTRRecord(domain string) bool { + s.mu.RLock() + defer s.mu.RUnlock() + + // Convert reverse DNS format to IP address + ip := reverseDNSToIP(domain) + if ip == nil { + return false + } + + _, ok := s.ptrRecords[ip.String()] + return ok +} + // Clear removes all records from the store func (s *DNSRecordStore) Clear() { s.mu.Lock() @@ -244,6 +343,7 @@ func (s *DNSRecordStore) Clear() { s.aaaaRecords = make(map[string][]net.IP) s.aWildcards = make(map[string][]net.IP) s.aaaaWildcards = make(map[string][]net.IP) + s.ptrRecords = make(map[string]string) } // removeIP is a helper function to remove a specific IP from a slice @@ -323,3 +423,75 @@ func matchWildcardInternal(pattern, domain string, pi, di int) bool { return matchWildcardInternal(pattern, domain, pi+1, di+1) } + +// reverseDNSToIP converts a reverse DNS query name to an IP address +// Supports both IPv4 (in-addr.arpa) and IPv6 (ip6.arpa) formats +func reverseDNSToIP(domain string) net.IP { + // Normalize to lowercase and ensure FQDN + domain = strings.ToLower(dns.Fqdn(domain)) + + // Check for IPv4 reverse DNS (in-addr.arpa) + if strings.HasSuffix(domain, ".in-addr.arpa.") { + // Remove the suffix + ipPart := strings.TrimSuffix(domain, ".in-addr.arpa.") + // Split by dots and reverse + parts := strings.Split(ipPart, ".") + if len(parts) != 4 { + return nil + } + // Reverse the octets + reversed := make([]string, 4) + for i := 0; i < 4; i++ { + reversed[i] = parts[3-i] + } + // Parse as IP + return net.ParseIP(strings.Join(reversed, ".")) + } + + // Check for IPv6 reverse DNS (ip6.arpa) + if strings.HasSuffix(domain, ".ip6.arpa.") { + // Remove the suffix + ipPart := strings.TrimSuffix(domain, ".ip6.arpa.") + // Split by dots and reverse + parts := strings.Split(ipPart, ".") + if len(parts) != 32 { + return nil + } + // Reverse the nibbles and group into 16-bit hex values + reversed := make([]string, 32) + for i := 0; i < 32; i++ { + reversed[i] = parts[31-i] + } + // Join into IPv6 format (groups of 4 nibbles separated by colons) + var ipv6Parts []string + for i := 0; i < 32; i += 4 { + ipv6Parts = append(ipv6Parts, reversed[i]+reversed[i+1]+reversed[i+2]+reversed[i+3]) + } + // Parse as IP + return net.ParseIP(strings.Join(ipv6Parts, ":")) + } + + return nil +} + +// IPToReverseDNS converts an IP address to reverse DNS format +// Returns the domain name for PTR queries (e.g., "1.0.0.127.in-addr.arpa.") +func IPToReverseDNS(ip net.IP) string { + if ip4 := ip.To4(); ip4 != nil { + // IPv4: reverse octets and append .in-addr.arpa. + return dns.Fqdn(fmt.Sprintf("%d.%d.%d.%d.in-addr.arpa", + ip4[3], ip4[2], ip4[1], ip4[0])) + } + + if ip6 := ip.To16(); ip6 != nil && ip.To4() == nil { + // IPv6: expand to 32 nibbles, reverse, and append .ip6.arpa. + var nibbles []string + for i := 15; i >= 0; i-- { + nibbles = append(nibbles, fmt.Sprintf("%x", ip6[i]&0x0f)) + nibbles = append(nibbles, fmt.Sprintf("%x", ip6[i]>>4)) + } + return dns.Fqdn(strings.Join(nibbles, ".") + ".ip6.arpa") + } + + return "" +} \ No newline at end of file diff --git a/dns/dns_records_test.go b/dns/dns_records_test.go index dbefaa0..eae9372 100644 --- a/dns/dns_records_test.go +++ b/dns/dns_records_test.go @@ -413,3 +413,452 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) { t.Error("Expected HasRecord to return true for mixed case wildcard match") } } + +func TestPTRRecordIPv4(t *testing.T) { + store := NewDNSRecordStore() + + // Add PTR record for IPv4 + ip := net.ParseIP("192.168.1.1") + domain := "host.example.com." + err := store.AddPTRRecord(ip, domain) + if err != nil { + t.Fatalf("Failed to add PTR record: %v", err) + } + + // Test reverse DNS lookup + reverseDomain := "1.1.168.192.in-addr.arpa." + result, ok := store.GetPTRRecord(reverseDomain) + if !ok { + t.Error("Expected PTR record to be found") + } + if result != domain { + t.Errorf("Expected domain %q, got %q", domain, result) + } + + // Test HasPTRRecord + if !store.HasPTRRecord(reverseDomain) { + t.Error("Expected HasPTRRecord to return true") + } + + // Test non-existent PTR record + _, ok = store.GetPTRRecord("2.1.168.192.in-addr.arpa.") + if ok { + t.Error("Expected PTR record not to be found for different IP") + } +} + +func TestPTRRecordIPv6(t *testing.T) { + store := NewDNSRecordStore() + + // Add PTR record for IPv6 + ip := net.ParseIP("2001:db8::1") + domain := "ipv6host.example.com." + err := store.AddPTRRecord(ip, domain) + if err != nil { + t.Fatalf("Failed to add PTR record: %v", err) + } + + // Test reverse DNS lookup + // 2001:db8::1 = 2001:0db8:0000:0000:0000:0000:0000:0001 + // Reverse: 1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa. + reverseDomain := "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa." + result, ok := store.GetPTRRecord(reverseDomain) + if !ok { + t.Error("Expected IPv6 PTR record to be found") + } + if result != domain { + t.Errorf("Expected domain %q, got %q", domain, result) + } + + // Test HasPTRRecord + if !store.HasPTRRecord(reverseDomain) { + t.Error("Expected HasPTRRecord to return true for IPv6") + } +} + +func TestRemovePTRRecord(t *testing.T) { + store := NewDNSRecordStore() + + // Add PTR record + ip := net.ParseIP("10.0.0.1") + domain := "test.example.com." + err := store.AddPTRRecord(ip, domain) + if err != nil { + t.Fatalf("Failed to add PTR record: %v", err) + } + + // Verify it exists + reverseDomain := "1.0.0.10.in-addr.arpa." + _, ok := store.GetPTRRecord(reverseDomain) + if !ok { + t.Error("Expected PTR record to exist before removal") + } + + // Remove PTR record + store.RemovePTRRecord(ip) + + // Verify it's gone + _, ok = store.GetPTRRecord(reverseDomain) + if ok { + t.Error("Expected PTR record to be removed") + } + + // Test HasPTRRecord after removal + if store.HasPTRRecord(reverseDomain) { + t.Error("Expected HasPTRRecord to return false after removal") + } +} + +func TestIPToReverseDNS(t *testing.T) { + tests := []struct { + name string + ip string + expected string + }{ + { + name: "IPv4 simple", + ip: "192.168.1.1", + expected: "1.1.168.192.in-addr.arpa.", + }, + { + name: "IPv4 localhost", + ip: "127.0.0.1", + expected: "1.0.0.127.in-addr.arpa.", + }, + { + name: "IPv4 with zeros", + ip: "10.0.0.1", + expected: "1.0.0.10.in-addr.arpa.", + }, + { + name: "IPv6 simple", + ip: "2001:db8::1", + expected: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", + }, + { + name: "IPv6 localhost", + ip: "::1", + expected: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ip := net.ParseIP(tt.ip) + if ip == nil { + t.Fatalf("Failed to parse IP: %s", tt.ip) + } + result := IPToReverseDNS(ip) + if result != tt.expected { + t.Errorf("IPToReverseDNS(%s) = %q, want %q", tt.ip, result, tt.expected) + } + }) + } +} + +func TestReverseDNSToIP(t *testing.T) { + tests := []struct { + name string + reverseDNS string + expectedIP string + shouldMatch bool + }{ + { + name: "IPv4 simple", + reverseDNS: "1.1.168.192.in-addr.arpa.", + expectedIP: "192.168.1.1", + shouldMatch: true, + }, + { + name: "IPv4 localhost", + reverseDNS: "1.0.0.127.in-addr.arpa.", + expectedIP: "127.0.0.1", + shouldMatch: true, + }, + { + name: "IPv6 simple", + reverseDNS: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", + expectedIP: "2001:db8::1", + shouldMatch: true, + }, + { + name: "Invalid IPv4 format", + reverseDNS: "1.1.168.in-addr.arpa.", + expectedIP: "", + shouldMatch: false, + }, + { + name: "Invalid IPv6 format", + reverseDNS: "1.0.0.0.ip6.arpa.", + expectedIP: "", + shouldMatch: false, + }, + { + name: "Not a reverse DNS domain", + reverseDNS: "example.com.", + expectedIP: "", + shouldMatch: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := reverseDNSToIP(tt.reverseDNS) + if tt.shouldMatch { + if result == nil { + t.Errorf("reverseDNSToIP(%q) returned nil, expected IP", tt.reverseDNS) + return + } + expectedIP := net.ParseIP(tt.expectedIP) + if !result.Equal(expectedIP) { + t.Errorf("reverseDNSToIP(%q) = %v, want %v", tt.reverseDNS, result, expectedIP) + } + } else { + if result != nil { + t.Errorf("reverseDNSToIP(%q) = %v, expected nil", tt.reverseDNS, result) + } + } + }) + } +} + +func TestPTRRecordCaseInsensitive(t *testing.T) { + store := NewDNSRecordStore() + + // Add PTR record with mixed case domain + ip := net.ParseIP("192.168.1.1") + domain := "MyHost.Example.Com" + err := store.AddPTRRecord(ip, domain) + if err != nil { + t.Fatalf("Failed to add PTR record: %v", err) + } + + // Test lookup with different cases in reverse DNS format + reverseDomain := "1.1.168.192.in-addr.arpa." + result, ok := store.GetPTRRecord(reverseDomain) + if !ok { + t.Error("Expected PTR record to be found") + } + // Domain should be normalized to lowercase + if result != "myhost.example.com." { + t.Errorf("Expected normalized domain %q, got %q", "myhost.example.com.", result) + } + + // Test with uppercase reverse DNS + reverseDomainUpper := "1.1.168.192.IN-ADDR.ARPA." + result, ok = store.GetPTRRecord(reverseDomainUpper) + if !ok { + t.Error("Expected PTR record to be found with uppercase reverse DNS") + } + if result != "myhost.example.com." { + t.Errorf("Expected normalized domain %q, got %q", "myhost.example.com.", result) + } +} + +func TestClearPTRRecords(t *testing.T) { + store := NewDNSRecordStore() + + // Add some PTR records + ip1 := net.ParseIP("192.168.1.1") + ip2 := net.ParseIP("192.168.1.2") + store.AddPTRRecord(ip1, "host1.example.com.") + store.AddPTRRecord(ip2, "host2.example.com.") + + // Add some A records too + store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1")) + + // Verify PTR records exist + if !store.HasPTRRecord("1.1.168.192.in-addr.arpa.") { + t.Error("Expected PTR record to exist before clear") + } + + // Clear all records + store.Clear() + + // Verify PTR records are gone + if store.HasPTRRecord("1.1.168.192.in-addr.arpa.") { + t.Error("Expected PTR record to be cleared") + } + if store.HasPTRRecord("2.1.168.192.in-addr.arpa.") { + t.Error("Expected PTR record to be cleared") + } + + // Verify A records are also gone + if store.HasRecord("test.example.com.", RecordTypeA) { + t.Error("Expected A record to be cleared") + } +} + +func TestAutomaticPTRRecordOnAdd(t *testing.T) { + store := NewDNSRecordStore() + + // Add an A record - should automatically add PTR record + domain := "host.example.com." + ip := net.ParseIP("192.168.1.100") + err := store.AddRecord(domain, ip) + if err != nil { + t.Fatalf("Failed to add A record: %v", err) + } + + // Verify PTR record was automatically created + reverseDomain := "100.1.168.192.in-addr.arpa." + result, ok := store.GetPTRRecord(reverseDomain) + if !ok { + t.Error("Expected PTR record to be automatically created") + } + if result != domain { + t.Errorf("Expected PTR to point to %q, got %q", domain, result) + } + + // Add AAAA record - should also automatically add PTR record + domain6 := "ipv6host.example.com." + ip6 := net.ParseIP("2001:db8::1") + err = store.AddRecord(domain6, ip6) + if err != nil { + t.Fatalf("Failed to add AAAA record: %v", err) + } + + // Verify IPv6 PTR record was automatically created + reverseDomain6 := "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa." + result6, ok := store.GetPTRRecord(reverseDomain6) + if !ok { + t.Error("Expected IPv6 PTR record to be automatically created") + } + if result6 != domain6 { + t.Errorf("Expected PTR to point to %q, got %q", domain6, result6) + } +} + +func TestAutomaticPTRRecordOnRemove(t *testing.T) { + store := NewDNSRecordStore() + + // Add an A record (with automatic PTR) + domain := "host.example.com." + ip := net.ParseIP("192.168.1.100") + store.AddRecord(domain, ip) + + // Verify PTR exists + reverseDomain := "100.1.168.192.in-addr.arpa." + if !store.HasPTRRecord(reverseDomain) { + t.Error("Expected PTR record to exist after adding A record") + } + + // Remove the A record + store.RemoveRecord(domain, ip) + + // Verify PTR was automatically removed + if store.HasPTRRecord(reverseDomain) { + t.Error("Expected PTR record to be automatically removed") + } + + // Verify A record is also gone + ips := store.GetRecords(domain, RecordTypeA) + if len(ips) != 0 { + t.Errorf("Expected A record to be removed, got %d records", len(ips)) + } +} + +func TestAutomaticPTRRecordOnRemoveAll(t *testing.T) { + store := NewDNSRecordStore() + + // Add multiple IPs for the same domain + domain := "host.example.com." + ip1 := net.ParseIP("192.168.1.100") + ip2 := net.ParseIP("192.168.1.101") + store.AddRecord(domain, ip1) + store.AddRecord(domain, ip2) + + // Verify both PTR records exist + reverseDomain1 := "100.1.168.192.in-addr.arpa." + reverseDomain2 := "101.1.168.192.in-addr.arpa." + if !store.HasPTRRecord(reverseDomain1) { + t.Error("Expected first PTR record to exist") + } + if !store.HasPTRRecord(reverseDomain2) { + t.Error("Expected second PTR record to exist") + } + + // Remove all records for the domain + store.RemoveRecord(domain, nil) + + // Verify both PTR records were removed + if store.HasPTRRecord(reverseDomain1) { + t.Error("Expected first PTR record to be removed") + } + if store.HasPTRRecord(reverseDomain2) { + t.Error("Expected second PTR record to be removed") + } +} + +func TestNoPTRForWildcardRecords(t *testing.T) { + store := NewDNSRecordStore() + + // Add wildcard record - should NOT create PTR record + domain := "*.example.com." + ip := net.ParseIP("192.168.1.100") + err := store.AddRecord(domain, ip) + if err != nil { + t.Fatalf("Failed to add wildcard record: %v", err) + } + + // Verify no PTR record was created + reverseDomain := "100.1.168.192.in-addr.arpa." + _, ok := store.GetPTRRecord(reverseDomain) + if ok { + t.Error("Expected no PTR record for wildcard domain") + } + + // Verify wildcard A record exists + if !store.HasRecord("host.example.com.", RecordTypeA) { + t.Error("Expected wildcard A record to exist") + } +} + +func TestPTRRecordOverwrite(t *testing.T) { + store := NewDNSRecordStore() + + // Add first domain with IP + domain1 := "host1.example.com." + ip := net.ParseIP("192.168.1.100") + store.AddRecord(domain1, ip) + + // Verify PTR points to first domain + reverseDomain := "100.1.168.192.in-addr.arpa." + result, ok := store.GetPTRRecord(reverseDomain) + if !ok { + t.Fatal("Expected PTR record to exist") + } + if result != domain1 { + t.Errorf("Expected PTR to point to %q, got %q", domain1, result) + } + + // Add second domain with same IP - should overwrite PTR + domain2 := "host2.example.com." + store.AddRecord(domain2, ip) + + // Verify PTR now points to second domain (last one added) + result, ok = store.GetPTRRecord(reverseDomain) + if !ok { + t.Fatal("Expected PTR record to still exist") + } + if result != domain2 { + t.Errorf("Expected PTR to point to %q (overwritten), got %q", domain2, result) + } + + // Remove first domain - PTR should remain pointing to second domain + store.RemoveRecord(domain1, ip) + result, ok = store.GetPTRRecord(reverseDomain) + if !ok { + t.Error("Expected PTR record to still exist after removing first domain") + } + if result != domain2 { + t.Errorf("Expected PTR to still point to %q, got %q", domain2, result) + } + + // Remove second domain - PTR should now be gone + store.RemoveRecord(domain2, ip) + _, ok = store.GetPTRRecord(reverseDomain) + if ok { + t.Error("Expected PTR record to be removed after removing second domain") + } +}