mirror of
https://github.com/fosrl/olm.git
synced 2026-03-13 14:16:41 +00:00
Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
53def4e2f6 | ||
|
|
e85fd9d71e | ||
|
|
98a24960f5 | ||
|
|
e82387d515 | ||
|
|
b3cb3e1c92 | ||
|
|
3f258d3500 | ||
|
|
ae88766d85 | ||
|
|
9ae49e36d5 | ||
|
|
5ca4825800 |
@@ -447,19 +447,20 @@ func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns
|
||||
return nil
|
||||
}
|
||||
|
||||
ips := p.recordStore.GetRecords(question.Name, recordType)
|
||||
if len(ips) == 0 {
|
||||
ips, exists := p.recordStore.GetRecords(question.Name, recordType)
|
||||
if !exists {
|
||||
// Domain not found in local records, forward to upstream
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Debug("Found %d local record(s) for %s", len(ips), question.Name)
|
||||
|
||||
// Create response message
|
||||
// Create response message (NODATA if no records, otherwise with answers)
|
||||
response := new(dns.Msg)
|
||||
response.SetReply(query)
|
||||
response.Authoritative = true
|
||||
|
||||
// Add answer records
|
||||
// Add answer records (loop is a no-op if ips is empty)
|
||||
for _, ip := range ips {
|
||||
var rr dns.RR
|
||||
if question.Qtype == dns.TypeA {
|
||||
@@ -730,8 +731,9 @@ func (p *DNSProxy) RemoveDNSRecord(domain string, ip net.IP) {
|
||||
p.recordStore.RemoveRecord(domain, ip)
|
||||
}
|
||||
|
||||
// GetDNSRecords returns all IP addresses for a domain and record type
|
||||
func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) []net.IP {
|
||||
// GetDNSRecords returns all IP addresses for a domain and record type.
|
||||
// The second return value indicates whether the domain exists.
|
||||
func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) ([]net.IP, bool) {
|
||||
return p.recordStore.GetRecords(domain, recordType)
|
||||
}
|
||||
|
||||
|
||||
178
dns/dns_proxy_test.go
Normal file
178
dns/dns_proxy_test.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestCheckLocalRecordsNODATAForAAAA(t *testing.T) {
|
||||
proxy := &DNSProxy{
|
||||
recordStore: NewDNSRecordStore(),
|
||||
}
|
||||
|
||||
// Add an A record for a domain (no AAAA record)
|
||||
ip := net.ParseIP("10.0.0.1")
|
||||
err := proxy.recordStore.AddRecord("myservice.internal", ip)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add A record: %v", err)
|
||||
}
|
||||
|
||||
// Query AAAA for domain with only A record - should return NODATA
|
||||
query := new(dns.Msg)
|
||||
query.SetQuestion("myservice.internal.", dns.TypeAAAA)
|
||||
response := proxy.checkLocalRecords(query, query.Question[0])
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("Expected NODATA response, got nil (would forward to upstream)")
|
||||
}
|
||||
if response.Rcode != dns.RcodeSuccess {
|
||||
t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode)
|
||||
}
|
||||
if len(response.Answer) != 0 {
|
||||
t.Errorf("Expected empty answer section for NODATA, got %d answers", len(response.Answer))
|
||||
}
|
||||
if !response.Authoritative {
|
||||
t.Error("Expected response to be authoritative")
|
||||
}
|
||||
|
||||
// Query A for same domain - should return the record
|
||||
query = new(dns.Msg)
|
||||
query.SetQuestion("myservice.internal.", dns.TypeA)
|
||||
response = proxy.checkLocalRecords(query, query.Question[0])
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("Expected response with A record, got nil")
|
||||
}
|
||||
if len(response.Answer) != 1 {
|
||||
t.Fatalf("Expected 1 answer, got %d", len(response.Answer))
|
||||
}
|
||||
aRecord, ok := response.Answer[0].(*dns.A)
|
||||
if !ok {
|
||||
t.Fatal("Expected A record in answer")
|
||||
}
|
||||
if !aRecord.A.Equal(ip.To4()) {
|
||||
t.Errorf("Expected IP %v, got %v", ip.To4(), aRecord.A)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckLocalRecordsNODATAForA(t *testing.T) {
|
||||
proxy := &DNSProxy{
|
||||
recordStore: NewDNSRecordStore(),
|
||||
}
|
||||
|
||||
// Add an AAAA record for a domain (no A record)
|
||||
ip := net.ParseIP("2001:db8::1")
|
||||
err := proxy.recordStore.AddRecord("ipv6only.internal", ip)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add AAAA record: %v", err)
|
||||
}
|
||||
|
||||
// Query A for domain with only AAAA record - should return NODATA
|
||||
query := new(dns.Msg)
|
||||
query.SetQuestion("ipv6only.internal.", dns.TypeA)
|
||||
response := proxy.checkLocalRecords(query, query.Question[0])
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("Expected NODATA response, got nil")
|
||||
}
|
||||
if response.Rcode != dns.RcodeSuccess {
|
||||
t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode)
|
||||
}
|
||||
if len(response.Answer) != 0 {
|
||||
t.Errorf("Expected empty answer section, got %d answers", len(response.Answer))
|
||||
}
|
||||
if !response.Authoritative {
|
||||
t.Error("Expected response to be authoritative")
|
||||
}
|
||||
|
||||
// Query AAAA for same domain - should return the record
|
||||
query = new(dns.Msg)
|
||||
query.SetQuestion("ipv6only.internal.", dns.TypeAAAA)
|
||||
response = proxy.checkLocalRecords(query, query.Question[0])
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("Expected response with AAAA record, got nil")
|
||||
}
|
||||
if len(response.Answer) != 1 {
|
||||
t.Fatalf("Expected 1 answer, got %d", len(response.Answer))
|
||||
}
|
||||
aaaaRecord, ok := response.Answer[0].(*dns.AAAA)
|
||||
if !ok {
|
||||
t.Fatal("Expected AAAA record in answer")
|
||||
}
|
||||
if !aaaaRecord.AAAA.Equal(ip) {
|
||||
t.Errorf("Expected IP %v, got %v", ip, aaaaRecord.AAAA)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckLocalRecordsNonExistentDomain(t *testing.T) {
|
||||
proxy := &DNSProxy{
|
||||
recordStore: NewDNSRecordStore(),
|
||||
}
|
||||
|
||||
// Add a record so the store isn't empty
|
||||
err := proxy.recordStore.AddRecord("exists.internal", net.ParseIP("10.0.0.1"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add record: %v", err)
|
||||
}
|
||||
|
||||
// Query A for non-existent domain - should return nil (forward to upstream)
|
||||
query := new(dns.Msg)
|
||||
query.SetQuestion("unknown.internal.", dns.TypeA)
|
||||
response := proxy.checkLocalRecords(query, query.Question[0])
|
||||
|
||||
if response != nil {
|
||||
t.Error("Expected nil for non-existent domain, got response")
|
||||
}
|
||||
|
||||
// Query AAAA for non-existent domain - should also return nil
|
||||
query = new(dns.Msg)
|
||||
query.SetQuestion("unknown.internal.", dns.TypeAAAA)
|
||||
response = proxy.checkLocalRecords(query, query.Question[0])
|
||||
|
||||
if response != nil {
|
||||
t.Error("Expected nil for non-existent domain, got response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckLocalRecordsNODATAWildcard(t *testing.T) {
|
||||
proxy := &DNSProxy{
|
||||
recordStore: NewDNSRecordStore(),
|
||||
}
|
||||
|
||||
// Add a wildcard A record (no AAAA)
|
||||
ip := net.ParseIP("10.0.0.1")
|
||||
err := proxy.recordStore.AddRecord("*.wildcard.internal", ip)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add wildcard A record: %v", err)
|
||||
}
|
||||
|
||||
// Query AAAA for wildcard-matched domain - should return NODATA
|
||||
query := new(dns.Msg)
|
||||
query.SetQuestion("host.wildcard.internal.", dns.TypeAAAA)
|
||||
response := proxy.checkLocalRecords(query, query.Question[0])
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("Expected NODATA response for wildcard match, got nil")
|
||||
}
|
||||
if response.Rcode != dns.RcodeSuccess {
|
||||
t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode)
|
||||
}
|
||||
if len(response.Answer) != 0 {
|
||||
t.Errorf("Expected empty answer section, got %d answers", len(response.Answer))
|
||||
}
|
||||
|
||||
// Query A for wildcard-matched domain - should return the record
|
||||
query = new(dns.Msg)
|
||||
query.SetQuestion("host.wildcard.internal.", dns.TypeA)
|
||||
response = proxy.checkLocalRecords(query, query.Question[0])
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("Expected response with A record, got nil")
|
||||
}
|
||||
if len(response.Answer) != 1 {
|
||||
t.Fatalf("Expected 1 answer, got %d", len(response.Answer))
|
||||
}
|
||||
}
|
||||
@@ -18,24 +18,27 @@ const (
|
||||
RecordTypePTR RecordType = RecordType(dns.TypePTR)
|
||||
)
|
||||
|
||||
// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries
|
||||
// recordSet holds A and AAAA records for a single domain or wildcard pattern
|
||||
type recordSet struct {
|
||||
A []net.IP
|
||||
AAAA []net.IP
|
||||
}
|
||||
|
||||
// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries.
|
||||
// Exact domains are stored in a map; wildcard patterns are in a separate map.
|
||||
type DNSRecordStore struct {
|
||||
mu sync.RWMutex
|
||||
aRecords map[string][]net.IP // domain -> list of IPv4 addresses
|
||||
aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses
|
||||
aWildcards map[string][]net.IP // wildcard pattern -> list of IPv4 addresses
|
||||
aaaaWildcards map[string][]net.IP // wildcard pattern -> list of IPv6 addresses
|
||||
ptrRecords map[string]string // IP address string -> domain name
|
||||
mu sync.RWMutex
|
||||
exact map[string]*recordSet // normalized FQDN -> A/AAAA records
|
||||
wildcards map[string]*recordSet // wildcard pattern -> A/AAAA records
|
||||
ptrRecords map[string]string // IP address string -> domain name
|
||||
}
|
||||
|
||||
// NewDNSRecordStore creates a new DNS record store
|
||||
func NewDNSRecordStore() *DNSRecordStore {
|
||||
return &DNSRecordStore{
|
||||
aRecords: make(map[string][]net.IP),
|
||||
aaaaRecords: make(map[string][]net.IP),
|
||||
aWildcards: make(map[string][]net.IP),
|
||||
aaaaWildcards: make(map[string][]net.IP),
|
||||
ptrRecords: make(map[string]string),
|
||||
exact: make(map[string]*recordSet),
|
||||
wildcards: make(map[string]*recordSet),
|
||||
ptrRecords: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,39 +51,37 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Ensure domain ends with a dot (FQDN format)
|
||||
if len(domain) == 0 || domain[len(domain)-1] != '.' {
|
||||
domain = domain + "."
|
||||
}
|
||||
|
||||
// Normalize domain to lowercase FQDN
|
||||
domain = strings.ToLower(dns.Fqdn(domain))
|
||||
|
||||
// Check if domain contains wildcards
|
||||
isWildcard := strings.ContainsAny(domain, "*?")
|
||||
|
||||
if ip.To4() != nil {
|
||||
// IPv4 address
|
||||
if isWildcard {
|
||||
s.aWildcards[domain] = append(s.aWildcards[domain], ip)
|
||||
} else {
|
||||
s.aRecords[domain] = append(s.aRecords[domain], ip)
|
||||
// Automatically add PTR record for non-wildcard domains
|
||||
s.ptrRecords[ip.String()] = domain
|
||||
}
|
||||
} else if ip.To16() != nil {
|
||||
// IPv6 address
|
||||
if isWildcard {
|
||||
s.aaaaWildcards[domain] = append(s.aaaaWildcards[domain], ip)
|
||||
} else {
|
||||
s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip)
|
||||
// Automatically add PTR record for non-wildcard domains
|
||||
s.ptrRecords[ip.String()] = domain
|
||||
}
|
||||
} else {
|
||||
isV4 := ip.To4() != nil
|
||||
if !isV4 && ip.To16() == nil {
|
||||
return &net.ParseError{Type: "IP address", Text: ip.String()}
|
||||
}
|
||||
|
||||
// Choose the appropriate map based on whether this is a wildcard
|
||||
m := s.exact
|
||||
if isWildcard {
|
||||
m = s.wildcards
|
||||
}
|
||||
|
||||
if m[domain] == nil {
|
||||
m[domain] = &recordSet{}
|
||||
}
|
||||
rs := m[domain]
|
||||
if isV4 {
|
||||
rs.A = append(rs.A, ip)
|
||||
} else {
|
||||
rs.AAAA = append(rs.AAAA, ip)
|
||||
}
|
||||
|
||||
// Add PTR record for non-wildcard domains
|
||||
if !isWildcard {
|
||||
s.ptrRecords[ip.String()] = domain
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -112,89 +113,62 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Ensure domain ends with a dot (FQDN format)
|
||||
if len(domain) == 0 || domain[len(domain)-1] != '.' {
|
||||
domain = domain + "."
|
||||
}
|
||||
|
||||
// Normalize domain to lowercase FQDN
|
||||
domain = strings.ToLower(dns.Fqdn(domain))
|
||||
|
||||
// Check if domain contains wildcards
|
||||
isWildcard := strings.ContainsAny(domain, "*?")
|
||||
|
||||
if ip == nil {
|
||||
// Remove all records for this domain
|
||||
if isWildcard {
|
||||
delete(s.aWildcards, domain)
|
||||
delete(s.aaaaWildcards, domain)
|
||||
} else {
|
||||
// For non-wildcard domains, remove PTR records for all IPs
|
||||
if ips, ok := s.aRecords[domain]; ok {
|
||||
for _, ipAddr := range ips {
|
||||
// Only remove PTR if it points to this domain
|
||||
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
|
||||
delete(s.ptrRecords, ipAddr.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
if ips, ok := s.aaaaRecords[domain]; ok {
|
||||
for _, ipAddr := range ips {
|
||||
// Only remove PTR if it points to this domain
|
||||
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
|
||||
delete(s.ptrRecords, ipAddr.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
delete(s.aRecords, domain)
|
||||
delete(s.aaaaRecords, domain)
|
||||
}
|
||||
// Choose the appropriate map
|
||||
m := s.exact
|
||||
if isWildcard {
|
||||
m = s.wildcards
|
||||
}
|
||||
|
||||
rs := m[domain]
|
||||
if rs == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if ip == nil {
|
||||
// Remove all records for this domain
|
||||
if !isWildcard {
|
||||
for _, ipAddr := range rs.A {
|
||||
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
|
||||
delete(s.ptrRecords, ipAddr.String())
|
||||
}
|
||||
}
|
||||
for _, ipAddr := range rs.AAAA {
|
||||
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
|
||||
delete(s.ptrRecords, ipAddr.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
delete(m, domain)
|
||||
return
|
||||
}
|
||||
|
||||
// Remove specific IP
|
||||
if ip.To4() != nil {
|
||||
// Remove specific IPv4 address
|
||||
if isWildcard {
|
||||
if ips, ok := s.aWildcards[domain]; ok {
|
||||
s.aWildcards[domain] = removeIP(ips, ip)
|
||||
if len(s.aWildcards[domain]) == 0 {
|
||||
delete(s.aWildcards, domain)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if ips, ok := s.aRecords[domain]; ok {
|
||||
s.aRecords[domain] = removeIP(ips, ip)
|
||||
if len(s.aRecords[domain]) == 0 {
|
||||
delete(s.aRecords, domain)
|
||||
}
|
||||
// Automatically remove PTR record if it points to this domain
|
||||
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
|
||||
delete(s.ptrRecords, ip.String())
|
||||
}
|
||||
rs.A = removeIP(rs.A, ip)
|
||||
if !isWildcard {
|
||||
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
|
||||
delete(s.ptrRecords, ip.String())
|
||||
}
|
||||
}
|
||||
} else if ip.To16() != nil {
|
||||
// Remove specific IPv6 address
|
||||
if isWildcard {
|
||||
if ips, ok := s.aaaaWildcards[domain]; ok {
|
||||
s.aaaaWildcards[domain] = removeIP(ips, ip)
|
||||
if len(s.aaaaWildcards[domain]) == 0 {
|
||||
delete(s.aaaaWildcards, domain)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if ips, ok := s.aaaaRecords[domain]; ok {
|
||||
s.aaaaRecords[domain] = removeIP(ips, ip)
|
||||
if len(s.aaaaRecords[domain]) == 0 {
|
||||
delete(s.aaaaRecords, domain)
|
||||
}
|
||||
// Automatically remove PTR record if it points to this domain
|
||||
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
|
||||
delete(s.ptrRecords, ip.String())
|
||||
}
|
||||
} else {
|
||||
rs.AAAA = removeIP(rs.AAAA, ip)
|
||||
if !isWildcard {
|
||||
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
|
||||
delete(s.ptrRecords, ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up empty record sets
|
||||
if len(rs.A) == 0 && len(rs.AAAA) == 0 {
|
||||
delete(m, domain)
|
||||
}
|
||||
}
|
||||
|
||||
// RemovePTRRecord removes a PTR record for an IP address
|
||||
@@ -205,61 +179,56 @@ func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) {
|
||||
delete(s.ptrRecords, ip.String())
|
||||
}
|
||||
|
||||
// GetRecords returns all IP addresses for a domain and record type
|
||||
// First checks for exact matches, then checks wildcard patterns
|
||||
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP {
|
||||
// GetRecords returns all IP addresses for a domain and record type.
|
||||
// The second return value indicates whether the domain exists at all
|
||||
// (true = domain exists, use NODATA if no records; false = NXDOMAIN).
|
||||
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) ([]net.IP, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// Normalize domain to lowercase FQDN
|
||||
domain = strings.ToLower(dns.Fqdn(domain))
|
||||
|
||||
var records []net.IP
|
||||
switch recordType {
|
||||
case RecordTypeA:
|
||||
// Check exact match first
|
||||
if ips, ok := s.aRecords[domain]; ok {
|
||||
// Return a copy to prevent external modifications
|
||||
records = make([]net.IP, len(ips))
|
||||
copy(records, ips)
|
||||
return records
|
||||
// Check exact match first
|
||||
if rs, exists := s.exact[domain]; exists {
|
||||
var ips []net.IP
|
||||
if recordType == RecordTypeA {
|
||||
ips = rs.A
|
||||
} else {
|
||||
ips = rs.AAAA
|
||||
}
|
||||
// Check wildcard patterns
|
||||
for pattern, ips := range s.aWildcards {
|
||||
if matchWildcard(pattern, domain) {
|
||||
records = append(records, ips...)
|
||||
}
|
||||
}
|
||||
if len(records) > 0 {
|
||||
// Return a copy
|
||||
result := make([]net.IP, len(records))
|
||||
copy(result, records)
|
||||
return result
|
||||
if len(ips) > 0 {
|
||||
out := make([]net.IP, len(ips))
|
||||
copy(out, ips)
|
||||
return out, true
|
||||
}
|
||||
// Domain exists but no records of this type
|
||||
return nil, true
|
||||
}
|
||||
|
||||
case RecordTypeAAAA:
|
||||
// Check exact match first
|
||||
if ips, ok := s.aaaaRecords[domain]; ok {
|
||||
// Return a copy to prevent external modifications
|
||||
records = make([]net.IP, len(ips))
|
||||
copy(records, ips)
|
||||
return records
|
||||
// Check wildcard matches
|
||||
var records []net.IP
|
||||
matched := false
|
||||
for pattern, rs := range s.wildcards {
|
||||
if !matchWildcard(pattern, domain) {
|
||||
continue
|
||||
}
|
||||
// Check wildcard patterns
|
||||
for pattern, ips := range s.aaaaWildcards {
|
||||
if matchWildcard(pattern, domain) {
|
||||
records = append(records, ips...)
|
||||
}
|
||||
}
|
||||
if len(records) > 0 {
|
||||
// Return a copy
|
||||
result := make([]net.IP, len(records))
|
||||
copy(result, records)
|
||||
return result
|
||||
matched = true
|
||||
if recordType == RecordTypeA {
|
||||
records = append(records, rs.A...)
|
||||
} else {
|
||||
records = append(records, rs.AAAA...)
|
||||
}
|
||||
}
|
||||
|
||||
return records
|
||||
if !matched {
|
||||
return nil, false
|
||||
}
|
||||
if len(records) == 0 {
|
||||
return nil, true
|
||||
}
|
||||
out := make([]net.IP, len(records))
|
||||
copy(out, records)
|
||||
return out, true
|
||||
}
|
||||
|
||||
// GetPTRRecord returns the domain name for a PTR record query
|
||||
@@ -288,34 +257,30 @@ func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// Normalize domain to lowercase FQDN
|
||||
domain = strings.ToLower(dns.Fqdn(domain))
|
||||
|
||||
switch recordType {
|
||||
case RecordTypeA:
|
||||
// Check exact match
|
||||
if _, ok := s.aRecords[domain]; ok {
|
||||
// Check exact match
|
||||
if rs, exists := s.exact[domain]; exists {
|
||||
if recordType == RecordTypeA && len(rs.A) > 0 {
|
||||
return true
|
||||
}
|
||||
// Check wildcard patterns
|
||||
for pattern := range s.aWildcards {
|
||||
if matchWildcard(pattern, domain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
case RecordTypeAAAA:
|
||||
// Check exact match
|
||||
if _, ok := s.aaaaRecords[domain]; ok {
|
||||
if recordType == RecordTypeAAAA && len(rs.AAAA) > 0 {
|
||||
return true
|
||||
}
|
||||
// Check wildcard patterns
|
||||
for pattern := range s.aaaaWildcards {
|
||||
if matchWildcard(pattern, domain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check wildcard matches
|
||||
for pattern, rs := range s.wildcards {
|
||||
if !matchWildcard(pattern, domain) {
|
||||
continue
|
||||
}
|
||||
if recordType == RecordTypeA && len(rs.A) > 0 {
|
||||
return true
|
||||
}
|
||||
if recordType == RecordTypeAAAA && len(rs.AAAA) > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -339,10 +304,8 @@ func (s *DNSRecordStore) Clear() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.aRecords = make(map[string][]net.IP)
|
||||
s.aaaaRecords = make(map[string][]net.IP)
|
||||
s.aWildcards = make(map[string][]net.IP)
|
||||
s.aaaaWildcards = make(map[string][]net.IP)
|
||||
s.exact = make(map[string]*recordSet)
|
||||
s.wildcards = make(map[string]*recordSet)
|
||||
s.ptrRecords = make(map[string]string)
|
||||
}
|
||||
|
||||
@@ -494,4 +457,4 @@ func IPToReverseDNS(ip net.IP) string {
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
@@ -183,25 +183,34 @@ func TestDNSRecordStoreWildcard(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test exact match takes precedence
|
||||
ips := store.GetRecords("exact.autoco.internal.", RecordTypeA)
|
||||
ips, exists := store.GetRecords("exact.autoco.internal.", RecordTypeA)
|
||||
if !exists {
|
||||
t.Error("Expected domain to exist")
|
||||
}
|
||||
if len(ips) != 1 {
|
||||
t.Errorf("Expected 1 IP for exact match, got %d", len(ips))
|
||||
}
|
||||
if !ips[0].Equal(exactIP) {
|
||||
if len(ips) > 0 && !ips[0].Equal(exactIP) {
|
||||
t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0])
|
||||
}
|
||||
|
||||
// Test wildcard match
|
||||
ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
|
||||
ips, exists = store.GetRecords("host.autoco.internal.", RecordTypeA)
|
||||
if !exists {
|
||||
t.Error("Expected wildcard match to exist")
|
||||
}
|
||||
if len(ips) != 1 {
|
||||
t.Errorf("Expected 1 IP for wildcard match, got %d", len(ips))
|
||||
}
|
||||
if !ips[0].Equal(wildcardIP) {
|
||||
if len(ips) > 0 && !ips[0].Equal(wildcardIP) {
|
||||
t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0])
|
||||
}
|
||||
|
||||
// Test non-match (base domain)
|
||||
ips = store.GetRecords("autoco.internal.", RecordTypeA)
|
||||
ips, exists = store.GetRecords("autoco.internal.", RecordTypeA)
|
||||
if exists {
|
||||
t.Error("Expected base domain to not exist")
|
||||
}
|
||||
if len(ips) != 0 {
|
||||
t.Errorf("Expected 0 IPs for base domain, got %d", len(ips))
|
||||
}
|
||||
@@ -218,7 +227,10 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test matching domain
|
||||
ips := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA)
|
||||
ips, exists := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA)
|
||||
if !exists {
|
||||
t.Error("Expected complex wildcard match to exist")
|
||||
}
|
||||
if len(ips) != 1 {
|
||||
t.Errorf("Expected 1 IP for complex wildcard match, got %d", len(ips))
|
||||
}
|
||||
@@ -227,13 +239,19 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test non-matching domain (missing prefix)
|
||||
ips = store.GetRecords("host-01.autoco.internal.", RecordTypeA)
|
||||
ips, exists = store.GetRecords("host-01.autoco.internal.", RecordTypeA)
|
||||
if exists {
|
||||
t.Error("Expected domain without prefix to not exist")
|
||||
}
|
||||
if len(ips) != 0 {
|
||||
t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips))
|
||||
}
|
||||
|
||||
// Test non-matching domain (wrong ? position)
|
||||
ips = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA)
|
||||
ips, exists = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA)
|
||||
if exists {
|
||||
t.Error("Expected domain with wrong ? match to not exist")
|
||||
}
|
||||
if len(ips) != 0 {
|
||||
t.Errorf("Expected 0 IPs for domain with wrong ? match, got %d", len(ips))
|
||||
}
|
||||
@@ -250,7 +268,10 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify it exists
|
||||
ips := store.GetRecords("host.autoco.internal.", RecordTypeA)
|
||||
ips, exists := store.GetRecords("host.autoco.internal.", RecordTypeA)
|
||||
if !exists {
|
||||
t.Error("Expected domain to exist before removal")
|
||||
}
|
||||
if len(ips) != 1 {
|
||||
t.Errorf("Expected 1 IP before removal, got %d", len(ips))
|
||||
}
|
||||
@@ -259,7 +280,10 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
|
||||
store.RemoveRecord("*.autoco.internal", nil)
|
||||
|
||||
// Verify it's gone
|
||||
ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
|
||||
ips, exists = store.GetRecords("host.autoco.internal.", RecordTypeA)
|
||||
if exists {
|
||||
t.Error("Expected domain to not exist after removal")
|
||||
}
|
||||
if len(ips) != 0 {
|
||||
t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
|
||||
}
|
||||
@@ -290,19 +314,19 @@ func TestDNSRecordStoreMultipleWildcards(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test domain matching only the prod pattern and the broad pattern
|
||||
ips := store.GetRecords("host.prod.autoco.internal.", RecordTypeA)
|
||||
ips, _ := store.GetRecords("host.prod.autoco.internal.", RecordTypeA)
|
||||
if len(ips) != 2 {
|
||||
t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips))
|
||||
}
|
||||
|
||||
// Test domain matching only the dev pattern and the broad pattern
|
||||
ips = store.GetRecords("service.dev.autoco.internal.", RecordTypeA)
|
||||
ips, _ = store.GetRecords("service.dev.autoco.internal.", RecordTypeA)
|
||||
if len(ips) != 2 {
|
||||
t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips))
|
||||
}
|
||||
|
||||
// Test domain matching only the broad pattern
|
||||
ips = store.GetRecords("host.test.autoco.internal.", RecordTypeA)
|
||||
ips, _ = store.GetRecords("host.test.autoco.internal.", RecordTypeA)
|
||||
if len(ips) != 1 {
|
||||
t.Errorf("Expected 1 IP (broad only), got %d", len(ips))
|
||||
}
|
||||
@@ -319,7 +343,7 @@ func TestDNSRecordStoreIPv6Wildcard(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test wildcard match for IPv6
|
||||
ips := store.GetRecords("host.autoco.internal.", RecordTypeAAAA)
|
||||
ips, _ := store.GetRecords("host.autoco.internal.", RecordTypeAAAA)
|
||||
if len(ips) != 1 {
|
||||
t.Errorf("Expected 1 IPv6 for wildcard match, got %d", len(ips))
|
||||
}
|
||||
@@ -368,7 +392,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, domain := range testCases {
|
||||
ips := store.GetRecords(domain, RecordTypeA)
|
||||
ips, _ := store.GetRecords(domain, RecordTypeA)
|
||||
if len(ips) != 1 {
|
||||
t.Errorf("Expected 1 IP for domain %q, got %d", domain, len(ips))
|
||||
}
|
||||
@@ -392,7 +416,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, domain := range wildcardTestCases {
|
||||
ips := store.GetRecords(domain, RecordTypeA)
|
||||
ips, _ := store.GetRecords(domain, RecordTypeA)
|
||||
if len(ips) != 1 {
|
||||
t.Errorf("Expected 1 IP for wildcard domain %q, got %d", domain, len(ips))
|
||||
}
|
||||
@@ -403,7 +427,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
|
||||
|
||||
// Test removal with different case
|
||||
store.RemoveRecord("MYHOST.AUTOCO.INTERNAL", nil)
|
||||
ips := store.GetRecords("myhost.autoco.internal.", RecordTypeA)
|
||||
ips, _ := store.GetRecords("myhost.autoco.internal.", RecordTypeA)
|
||||
if len(ips) != 0 {
|
||||
t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
|
||||
}
|
||||
@@ -752,7 +776,7 @@ func TestAutomaticPTRRecordOnRemove(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify A record is also gone
|
||||
ips := store.GetRecords(domain, RecordTypeA)
|
||||
ips, _ := store.GetRecords(domain, RecordTypeA)
|
||||
if len(ips) != 0 {
|
||||
t.Errorf("Expected A record to be removed, got %d records", len(ips))
|
||||
}
|
||||
|
||||
14
go.mod
14
go.mod
@@ -1,14 +1,14 @@
|
||||
module github.com/fosrl/olm
|
||||
|
||||
go 1.25
|
||||
go 1.25.0
|
||||
|
||||
require (
|
||||
github.com/Microsoft/go-winio v0.6.2
|
||||
github.com/fosrl/newt v1.9.0
|
||||
github.com/fosrl/newt v1.10.3
|
||||
github.com/godbus/dbus/v5 v5.2.2
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/miekg/dns v1.1.70
|
||||
golang.org/x/sys v0.40.0
|
||||
golang.org/x/sys v0.41.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c
|
||||
@@ -20,13 +20,13 @@ require (
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/vishvananda/netlink v1.3.1 // indirect
|
||||
github.com/vishvananda/netns v0.0.5 // indirect
|
||||
golang.org/x/crypto v0.46.0 // indirect
|
||||
golang.org/x/crypto v0.48.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect
|
||||
golang.org/x/mod v0.31.0 // indirect
|
||||
golang.org/x/net v0.48.0 // indirect
|
||||
golang.org/x/mod v0.32.0 // indirect
|
||||
golang.org/x/net v0.51.0 // indirect
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
golang.org/x/time v0.12.0 // indirect
|
||||
golang.org/x/tools v0.40.0 // indirect
|
||||
golang.org/x/tools v0.41.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
|
||||
)
|
||||
|
||||
24
go.sum
24
go.sum
@@ -1,7 +1,7 @@
|
||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/fosrl/newt v1.9.0 h1:66eJMo6fA+YcBTbddxTfNJXNQo1WWKzmn6zPRP5kSDE=
|
||||
github.com/fosrl/newt v1.9.0/go.mod h1:d1+yYMnKqg4oLqAM9zdbjthjj2FQEVouiACjqU468ck=
|
||||
github.com/fosrl/newt v1.10.3 h1:JO9gFK9LP/w2EeDIn4wU+jKggAFPo06hX5hxFSETqcw=
|
||||
github.com/fosrl/newt v1.10.3/go.mod h1:iYuuCAG7iabheiogMOX87r61uQN31S39nKxMKRuLS+s=
|
||||
github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ=
|
||||
github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
|
||||
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
||||
@@ -16,24 +16,24 @@ github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW
|
||||
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
|
||||
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
|
||||
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
||||
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
||||
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
||||
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
|
||||
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
|
||||
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
|
||||
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
|
||||
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
|
||||
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
|
||||
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
|
||||
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
|
||||
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
||||
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
|
||||
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
|
||||
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
|
||||
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
|
||||
|
||||
@@ -168,10 +168,17 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
|
||||
SharedBind: o.sharedBind,
|
||||
WSClient: o.websocket,
|
||||
APIServer: o.apiServer,
|
||||
PublicDNS: o.tunnelConfig.PublicDNS,
|
||||
})
|
||||
|
||||
for i := range wgData.Sites {
|
||||
site := wgData.Sites[i]
|
||||
|
||||
if site.PublicKey == "" {
|
||||
logger.Warn("Skipping site %d (%s): no public key available (site may not be connected)", site.SiteId, site.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
var siteEndpoint string
|
||||
// here we are going to take the relay endpoint if it exists which means we requested a relay for this peer
|
||||
if site.RelayEndpoint != "" {
|
||||
|
||||
21
olm/olm.go
21
olm/olm.go
@@ -31,7 +31,7 @@ type Olm struct {
|
||||
privateKey wgtypes.Key
|
||||
logFile *os.File
|
||||
|
||||
registered bool
|
||||
registered bool
|
||||
tunnelRunning bool
|
||||
|
||||
uapiListener net.Listener
|
||||
@@ -111,7 +111,7 @@ func (o *Olm) initTunnelInfo(clientID string) error {
|
||||
logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount())
|
||||
|
||||
// Create the holepunch manager
|
||||
o.holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm", privateKey.PublicKey().String())
|
||||
o.holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm", privateKey.PublicKey().String(), o.tunnelConfig.PublicDNS)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -222,7 +222,7 @@ func (o *Olm) registerAPICallbacks() {
|
||||
tunnelConfig.MTU = 1420
|
||||
}
|
||||
if req.DNS == "" {
|
||||
tunnelConfig.DNS = "9.9.9.9"
|
||||
tunnelConfig.DNS = "8.8.8.8"
|
||||
}
|
||||
// DNSProxyIP has no default - it must be provided if DNS proxy is desired
|
||||
// UpstreamDNS defaults to 8.8.8.8 if not provided
|
||||
@@ -292,16 +292,23 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
||||
logger.Info("Tunnel already running")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// debug print out the whole config
|
||||
logger.Debug("Starting tunnel with config: %+v", config)
|
||||
|
||||
o.tunnelRunning = true // Also set it here in case it is called externally
|
||||
o.tunnelConfig = config
|
||||
|
||||
// TODO: we are hardcoding this for now but we should really pull it from the current config of the system
|
||||
if o.tunnelConfig.DNS != "" {
|
||||
o.tunnelConfig.PublicDNS = []string{o.tunnelConfig.DNS + ":53"}
|
||||
} else {
|
||||
o.tunnelConfig.PublicDNS = []string{"8.8.8.8:53"}
|
||||
}
|
||||
|
||||
// Reset terminated status when tunnel starts
|
||||
o.apiServer.SetTerminated(false)
|
||||
|
||||
|
||||
fingerprint := config.InitialFingerprint
|
||||
if fingerprint == nil {
|
||||
fingerprint = make(map[string]any)
|
||||
@@ -313,7 +320,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
||||
}
|
||||
|
||||
o.SetFingerprint(fingerprint)
|
||||
o.SetPostures(postures)
|
||||
o.SetPostures(postures)
|
||||
|
||||
// Create a cancellable context for this tunnel process
|
||||
tunnelCtx, cancel := context.WithCancel(o.olmCtx)
|
||||
@@ -387,7 +394,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
||||
|
||||
if o.registered {
|
||||
o.websocket.StartPingMonitor()
|
||||
|
||||
|
||||
logger.Debug("Already registered, skipping registration")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -37,6 +37,11 @@ func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
if siteConfig.PublicKey == "" {
|
||||
logger.Warn("Skipping add-peer for site %d (%s): no public key available (site may not be connected)", siteConfig.SiteId, siteConfig.Name)
|
||||
return
|
||||
}
|
||||
|
||||
_ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
|
||||
|
||||
if err := o.peerManager.AddPeer(siteConfig); err != nil {
|
||||
@@ -170,7 +175,7 @@ func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint)
|
||||
primaryRelay, err := util.ResolveDomainUpstream(relayData.RelayEndpoint, o.tunnelConfig.PublicDNS)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve primary relay endpoint: %v", err)
|
||||
return
|
||||
@@ -203,7 +208,7 @@ func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
primaryRelay, err := util.ResolveDomain(relayData.Endpoint)
|
||||
primaryRelay, err := util.ResolveDomainUpstream(relayData.Endpoint, o.tunnelConfig.PublicDNS)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
|
||||
}
|
||||
|
||||
@@ -61,6 +61,7 @@ type TunnelConfig struct {
|
||||
MTU int
|
||||
DNS string
|
||||
UpstreamDNS []string
|
||||
PublicDNS []string
|
||||
InterfaceName string
|
||||
|
||||
// Advanced
|
||||
|
||||
@@ -32,7 +32,8 @@ type PeerManagerConfig struct {
|
||||
SharedBind *bind.SharedBind
|
||||
// WSClient is optional - if nil, relay messages won't be sent
|
||||
WSClient *websocket.Client
|
||||
APIServer *api.API
|
||||
APIServer *api.API
|
||||
PublicDNS []string
|
||||
}
|
||||
|
||||
type PeerManager struct {
|
||||
@@ -50,7 +51,8 @@ type PeerManager struct {
|
||||
// key is the CIDR string, value is a set of siteIds that want this IP
|
||||
allowedIPClaims map[string]map[int]bool
|
||||
APIServer *api.API
|
||||
|
||||
publicDNS []string
|
||||
|
||||
PersistentKeepalive int
|
||||
}
|
||||
|
||||
@@ -65,6 +67,7 @@ func NewPeerManager(config PeerManagerConfig) *PeerManager {
|
||||
allowedIPOwners: make(map[string]int),
|
||||
allowedIPClaims: make(map[string]map[int]bool),
|
||||
APIServer: config.APIServer,
|
||||
publicDNS: config.PublicDNS,
|
||||
}
|
||||
|
||||
// Create the peer monitor
|
||||
@@ -74,6 +77,7 @@ func NewPeerManager(config PeerManagerConfig) *PeerManager {
|
||||
config.LocalIP,
|
||||
config.SharedBind,
|
||||
config.APIServer,
|
||||
config.PublicDNS,
|
||||
)
|
||||
|
||||
return pm
|
||||
@@ -129,7 +133,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
|
||||
wgConfig := siteConfig
|
||||
wgConfig.AllowedIps = ownedIPs
|
||||
|
||||
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil {
|
||||
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive, pm.publicDNS); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -270,7 +274,7 @@ func (pm *PeerManager) RemovePeer(siteId int) error {
|
||||
ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
|
||||
wgConfig := promotedPeer
|
||||
wgConfig.AllowedIps = ownedIPs
|
||||
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil {
|
||||
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive, pm.publicDNS); err != nil {
|
||||
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
|
||||
}
|
||||
}
|
||||
@@ -346,7 +350,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
|
||||
wgConfig := siteConfig
|
||||
wgConfig.AllowedIps = ownedIPs
|
||||
|
||||
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil {
|
||||
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive, pm.publicDNS); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -356,7 +360,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
|
||||
promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
|
||||
promotedWgConfig := promotedPeer
|
||||
promotedWgConfig.AllowedIps = promotedOwnedIPs
|
||||
if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil {
|
||||
if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive, pm.publicDNS); err != nil {
|
||||
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,6 +34,7 @@ type PeerMonitor struct {
|
||||
timeout time.Duration
|
||||
maxAttempts int
|
||||
wsClient *websocket.Client
|
||||
publicDNS []string
|
||||
|
||||
// Netstack fields
|
||||
middleDev *middleDevice.MiddleDevice
|
||||
@@ -82,7 +83,7 @@ type PeerMonitor struct {
|
||||
}
|
||||
|
||||
// NewPeerMonitor creates a new peer monitor with the given callback
|
||||
func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API) *PeerMonitor {
|
||||
func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API, publicDNS []string) *PeerMonitor {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pm := &PeerMonitor{
|
||||
monitors: make(map[int]*Client),
|
||||
@@ -91,6 +92,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
|
||||
wsClient: wsClient,
|
||||
middleDev: middleDev,
|
||||
localIP: localIP,
|
||||
publicDNS: publicDNS,
|
||||
activePorts: make(map[uint16]bool),
|
||||
nsCtx: ctx,
|
||||
nsCancel: cancel,
|
||||
@@ -124,7 +126,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
|
||||
|
||||
// Initialize holepunch tester if sharedBind is available
|
||||
if sharedBind != nil {
|
||||
pm.holepunchTester = holepunch.NewHolepunchTester(sharedBind)
|
||||
pm.holepunchTester = holepunch.NewHolepunchTester(sharedBind, publicDNS)
|
||||
}
|
||||
|
||||
return pm
|
||||
|
||||
@@ -11,14 +11,14 @@ import (
|
||||
)
|
||||
|
||||
// ConfigurePeer sets up or updates a peer within the WireGuard device
|
||||
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool, persistentKeepalive int) error {
|
||||
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool, persistentKeepalive int, publicDNS []string) error {
|
||||
var endpoint string
|
||||
if relay && siteConfig.RelayEndpoint != "" {
|
||||
endpoint = formatEndpoint(siteConfig.RelayEndpoint)
|
||||
} else {
|
||||
endpoint = formatEndpoint(siteConfig.Endpoint)
|
||||
}
|
||||
siteHost, err := util.ResolveDomain(endpoint)
|
||||
siteHost, err := util.ResolveDomainUpstream(endpoint, publicDNS)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user