mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
167 lines
3.9 KiB
Go
167 lines
3.9 KiB
Go
package dns
|
|
|
|
import (
|
|
"net"
|
|
"sync"
|
|
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
// RecordType represents the type of DNS record
|
|
type RecordType uint16
|
|
|
|
const (
|
|
RecordTypeA RecordType = RecordType(dns.TypeA)
|
|
RecordTypeAAAA RecordType = RecordType(dns.TypeAAAA)
|
|
)
|
|
|
|
// DNSRecordStore holds locally defined A and AAAA records, keyed by
// fully-qualified domain name. Its zero value is not ready for use (the
// record maps are nil, so adding a record would panic); create stores with
// NewDNSRecordStore.
type DNSRecordStore struct {
	// mu guards aRecords and aaaaRecords: readers take RLock, writers Lock.
	mu sync.RWMutex

	aRecords    map[string][]net.IP // FQDN -> IPv4 addresses answered for A queries
	aaaaRecords map[string][]net.IP // FQDN -> IPv6 addresses answered for AAAA queries
}

// NewDNSRecordStore returns an empty DNSRecordStore whose record maps are
// already allocated, so records can be added immediately.
func NewDNSRecordStore() *DNSRecordStore {
	store := &DNSRecordStore{}
	store.aRecords = map[string][]net.IP{}
	store.aaaaRecords = map[string][]net.IP{}
	return store
}
|
|
|
|
// AddRecord adds a DNS record mapping (A or AAAA)
|
|
// domain should be in FQDN format (e.g., "example.com.")
|
|
// ip should be a valid IPv4 or IPv6 address
|
|
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
// Ensure domain ends with a dot (FQDN format)
|
|
if len(domain) == 0 || domain[len(domain)-1] != '.' {
|
|
domain = domain + "."
|
|
}
|
|
|
|
// Normalize domain to lowercase
|
|
domain = dns.Fqdn(domain)
|
|
|
|
if ip.To4() != nil {
|
|
// IPv4 address
|
|
s.aRecords[domain] = append(s.aRecords[domain], ip)
|
|
} else if ip.To16() != nil {
|
|
// IPv6 address
|
|
s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip)
|
|
} else {
|
|
return &net.ParseError{Type: "IP address", Text: ip.String()}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// RemoveRecord removes a specific DNS record mapping
|
|
// If ip is nil, removes all records for the domain
|
|
func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
// Ensure domain ends with a dot (FQDN format)
|
|
if len(domain) == 0 || domain[len(domain)-1] != '.' {
|
|
domain = domain + "."
|
|
}
|
|
|
|
// Normalize domain to lowercase
|
|
domain = dns.Fqdn(domain)
|
|
|
|
if ip == nil {
|
|
// Remove all records for this domain
|
|
delete(s.aRecords, domain)
|
|
delete(s.aaaaRecords, domain)
|
|
return
|
|
}
|
|
|
|
if ip.To4() != nil {
|
|
// Remove specific IPv4 address
|
|
if ips, ok := s.aRecords[domain]; ok {
|
|
s.aRecords[domain] = removeIP(ips, ip)
|
|
if len(s.aRecords[domain]) == 0 {
|
|
delete(s.aRecords, domain)
|
|
}
|
|
}
|
|
} else if ip.To16() != nil {
|
|
// Remove specific IPv6 address
|
|
if ips, ok := s.aaaaRecords[domain]; ok {
|
|
s.aaaaRecords[domain] = removeIP(ips, ip)
|
|
if len(s.aaaaRecords[domain]) == 0 {
|
|
delete(s.aaaaRecords, domain)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// GetRecords returns all IP addresses for a domain and record type
|
|
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
// Normalize domain to lowercase FQDN
|
|
domain = dns.Fqdn(domain)
|
|
|
|
var records []net.IP
|
|
switch recordType {
|
|
case RecordTypeA:
|
|
if ips, ok := s.aRecords[domain]; ok {
|
|
// Return a copy to prevent external modifications
|
|
records = make([]net.IP, len(ips))
|
|
copy(records, ips)
|
|
}
|
|
case RecordTypeAAAA:
|
|
if ips, ok := s.aaaaRecords[domain]; ok {
|
|
// Return a copy to prevent external modifications
|
|
records = make([]net.IP, len(ips))
|
|
copy(records, ips)
|
|
}
|
|
}
|
|
|
|
return records
|
|
}
|
|
|
|
// HasRecord checks if a domain has any records of the specified type
|
|
func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
// Normalize domain to lowercase FQDN
|
|
domain = dns.Fqdn(domain)
|
|
|
|
switch recordType {
|
|
case RecordTypeA:
|
|
_, ok := s.aRecords[domain]
|
|
return ok
|
|
case RecordTypeAAAA:
|
|
_, ok := s.aaaaRecords[domain]
|
|
return ok
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// Clear removes all records from the store
|
|
func (s *DNSRecordStore) Clear() {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
s.aRecords = make(map[string][]net.IP)
|
|
s.aaaaRecords = make(map[string][]net.IP)
|
|
}
|
|
|
|
// removeIP returns a newly allocated slice that keeps, in their original order,
// every address in ips not equal to toRemove according to net.IP.Equal (so the
// 4-byte and 16-byte forms of one IPv4 address match). ips itself is not
// modified. The result is never nil, even when nothing remains.
func removeIP(ips []net.IP, toRemove net.IP) []net.IP {
	kept := make([]net.IP, 0, len(ips))
	for i := range ips {
		if ips[i].Equal(toRemove) {
			continue
		}
		kept = append(kept, ips[i])
	}
	return kept
}
|