Compare commits

..

9 Commits

Author SHA1 Message Date
Owen
8ce40501c3 Remove extra restore function 2026-03-11 17:16:33 -07:00
Owen
fc92ff355a Actually pull the upstream from the dns var 2026-03-11 16:54:53 -07:00
Owen
c6a486a0a6 Add hardcoded public dns 2026-03-11 16:47:01 -07:00
Owen
9c0e37eddb Send token 2026-02-24 19:47:30 -08:00
Owen
5527bff671 Merge branch 'dev' 2026-02-06 15:17:21 -08:00
Owen
af973b2440 Support prt records 2026-02-06 15:17:01 -08:00
Owen
dd9bff9a4b Fix peer names clearing 2026-02-02 18:03:29 -08:00
Owen
1be5e454ba Default override dns to true
Ref #59
2026-02-02 10:03:22 -08:00
Owen
4a25a0d413 Dont go unregistered when low power mode
Former-commit-id: 0938564038
2026-01-31 16:58:05 -08:00
15 changed files with 704 additions and 44 deletions

View File

@@ -272,9 +272,6 @@ func (s *API) SetConnectionStatus(isConnected bool) {
if isConnected { if isConnected {
s.connectedAt = time.Now() s.connectedAt = time.Now()
} else {
// Clear peer statuses when disconnected
s.peerStatuses = make(map[int]*PeerStatus)
} }
} }

View File

@@ -89,7 +89,7 @@ func DefaultConfig() *OlmConfig {
PingInterval: "3s", PingInterval: "3s",
PingTimeout: "5s", PingTimeout: "5s",
DisableHolepunch: false, DisableHolepunch: false,
OverrideDNS: false, OverrideDNS: true,
TunnelDNS: false, TunnelDNS: false,
// DoNotCreateNewClient: false, // DoNotCreateNewClient: false,
sources: make(map[string]string), sources: make(map[string]string),

View File

@@ -380,7 +380,7 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie
// Check if we have local records for this query // Check if we have local records for this query
var response *dns.Msg var response *dns.Msg
if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA { if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA || question.Qtype == dns.TypePTR {
response = p.checkLocalRecords(msg, question) response = p.checkLocalRecords(msg, question)
} }
@@ -410,6 +410,34 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie
// checkLocalRecords checks if we have local records for the query // checkLocalRecords checks if we have local records for the query
func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns.Msg { func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns.Msg {
// Handle PTR queries
if question.Qtype == dns.TypePTR {
if ptrDomain, ok := p.recordStore.GetPTRRecord(question.Name); ok {
logger.Debug("Found local PTR record for %s -> %s", question.Name, ptrDomain)
// Create response message
response := new(dns.Msg)
response.SetReply(query)
response.Authoritative = true
// Add PTR answer record
rr := &dns.PTR{
Hdr: dns.RR_Header{
Name: question.Name,
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 300, // 5 minutes
},
Ptr: ptrDomain,
}
response.Answer = append(response.Answer, rr)
return response
}
return nil
}
// Handle A and AAAA queries
var recordType RecordType var recordType RecordType
if question.Qtype == dns.TypeA { if question.Qtype == dns.TypeA {
recordType = RecordTypeA recordType = RecordTypeA

View File

@@ -1,6 +1,7 @@
package dns package dns
import ( import (
"fmt"
"net" "net"
"strings" "strings"
"sync" "sync"
@@ -14,15 +15,17 @@ type RecordType uint16
const ( const (
RecordTypeA RecordType = RecordType(dns.TypeA) RecordTypeA RecordType = RecordType(dns.TypeA)
RecordTypeAAAA RecordType = RecordType(dns.TypeAAAA) RecordTypeAAAA RecordType = RecordType(dns.TypeAAAA)
RecordTypePTR RecordType = RecordType(dns.TypePTR)
) )
// DNSRecordStore manages local DNS records for A and AAAA queries // DNSRecordStore manages local DNS records for A, AAAA, and PTR queries
type DNSRecordStore struct { type DNSRecordStore struct {
mu sync.RWMutex mu sync.RWMutex
aRecords map[string][]net.IP // domain -> list of IPv4 addresses aRecords map[string][]net.IP // domain -> list of IPv4 addresses
aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses
aWildcards map[string][]net.IP // wildcard pattern -> list of IPv4 addresses aWildcards map[string][]net.IP // wildcard pattern -> list of IPv4 addresses
aaaaWildcards map[string][]net.IP // wildcard pattern -> list of IPv6 addresses aaaaWildcards map[string][]net.IP // wildcard pattern -> list of IPv6 addresses
ptrRecords map[string]string // IP address string -> domain name
} }
// NewDNSRecordStore creates a new DNS record store // NewDNSRecordStore creates a new DNS record store
@@ -32,6 +35,7 @@ func NewDNSRecordStore() *DNSRecordStore {
aaaaRecords: make(map[string][]net.IP), aaaaRecords: make(map[string][]net.IP),
aWildcards: make(map[string][]net.IP), aWildcards: make(map[string][]net.IP),
aaaaWildcards: make(map[string][]net.IP), aaaaWildcards: make(map[string][]net.IP),
ptrRecords: make(map[string]string),
} }
} }
@@ -39,6 +43,7 @@ func NewDNSRecordStore() *DNSRecordStore {
// domain should be in FQDN format (e.g., "example.com.") // domain should be in FQDN format (e.g., "example.com.")
// domain can contain wildcards: * (0+ chars) and ? (exactly 1 char) // domain can contain wildcards: * (0+ chars) and ? (exactly 1 char)
// ip should be a valid IPv4 or IPv6 address // ip should be a valid IPv4 or IPv6 address
// Automatically adds a corresponding PTR record for non-wildcard domains
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error { func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@@ -60,6 +65,8 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
s.aWildcards[domain] = append(s.aWildcards[domain], ip) s.aWildcards[domain] = append(s.aWildcards[domain], ip)
} else { } else {
s.aRecords[domain] = append(s.aRecords[domain], ip) s.aRecords[domain] = append(s.aRecords[domain], ip)
// Automatically add PTR record for non-wildcard domains
s.ptrRecords[ip.String()] = domain
} }
} else if ip.To16() != nil { } else if ip.To16() != nil {
// IPv6 address // IPv6 address
@@ -67,6 +74,8 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
s.aaaaWildcards[domain] = append(s.aaaaWildcards[domain], ip) s.aaaaWildcards[domain] = append(s.aaaaWildcards[domain], ip)
} else { } else {
s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip) s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip)
// Automatically add PTR record for non-wildcard domains
s.ptrRecords[ip.String()] = domain
} }
} else { } else {
return &net.ParseError{Type: "IP address", Text: ip.String()} return &net.ParseError{Type: "IP address", Text: ip.String()}
@@ -75,8 +84,30 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
return nil return nil
} }
// AddPTRRecord adds a PTR record mapping an IP address to a domain name
// ip should be a valid IPv4 or IPv6 address
// domain should be in FQDN format (e.g., "example.com.")
func (s *DNSRecordStore) AddPTRRecord(ip net.IP, domain string) error {
s.mu.Lock()
defer s.mu.Unlock()
// Ensure domain ends with a dot (FQDN format)
if len(domain) == 0 || domain[len(domain)-1] != '.' {
domain = domain + "."
}
// Normalize domain to lowercase FQDN
domain = strings.ToLower(dns.Fqdn(domain))
// Store PTR record using IP string as key
s.ptrRecords[ip.String()] = domain
return nil
}
// RemoveRecord removes a specific DNS record mapping // RemoveRecord removes a specific DNS record mapping
// If ip is nil, removes all records for the domain (including wildcards) // If ip is nil, removes all records for the domain (including wildcards)
// Automatically removes corresponding PTR records for non-wildcard domains
func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) { func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@@ -98,6 +129,23 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
delete(s.aWildcards, domain) delete(s.aWildcards, domain)
delete(s.aaaaWildcards, domain) delete(s.aaaaWildcards, domain)
} else { } else {
// For non-wildcard domains, remove PTR records for all IPs
if ips, ok := s.aRecords[domain]; ok {
for _, ipAddr := range ips {
// Only remove PTR if it points to this domain
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ipAddr.String())
}
}
}
if ips, ok := s.aaaaRecords[domain]; ok {
for _, ipAddr := range ips {
// Only remove PTR if it points to this domain
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ipAddr.String())
}
}
}
delete(s.aRecords, domain) delete(s.aRecords, domain)
delete(s.aaaaRecords, domain) delete(s.aaaaRecords, domain)
} }
@@ -119,6 +167,10 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
if len(s.aRecords[domain]) == 0 { if len(s.aRecords[domain]) == 0 {
delete(s.aRecords, domain) delete(s.aRecords, domain)
} }
// Automatically remove PTR record if it points to this domain
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ip.String())
}
} }
} }
} else if ip.To16() != nil { } else if ip.To16() != nil {
@@ -136,11 +188,23 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
if len(s.aaaaRecords[domain]) == 0 { if len(s.aaaaRecords[domain]) == 0 {
delete(s.aaaaRecords, domain) delete(s.aaaaRecords, domain)
} }
// Automatically remove PTR record if it points to this domain
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ip.String())
}
} }
} }
} }
} }
// RemovePTRRecord removes a PTR record for an IP address
func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.ptrRecords, ip.String())
}
// GetRecords returns all IP addresses for a domain and record type // GetRecords returns all IP addresses for a domain and record type
// First checks for exact matches, then checks wildcard patterns // First checks for exact matches, then checks wildcard patterns
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP { func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP {
@@ -198,6 +262,26 @@ func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.
return records return records
} }
// GetPTRRecord returns the domain name for a PTR record query
// domain should be in reverse DNS format (e.g., "1.0.0.127.in-addr.arpa.")
func (s *DNSRecordStore) GetPTRRecord(domain string) (string, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
// Convert reverse DNS format to IP address
ip := reverseDNSToIP(domain)
if ip == nil {
return "", false
}
// Look up the PTR record
if ptrDomain, ok := s.ptrRecords[ip.String()]; ok {
return ptrDomain, true
}
return "", false
}
// HasRecord checks if a domain has any records of the specified type // HasRecord checks if a domain has any records of the specified type
// Checks both exact matches and wildcard patterns // Checks both exact matches and wildcard patterns
func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool { func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
@@ -235,6 +319,21 @@ func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
return false return false
} }
// HasPTRRecord checks if a PTR record exists for the given reverse DNS domain
func (s *DNSRecordStore) HasPTRRecord(domain string) bool {
s.mu.RLock()
defer s.mu.RUnlock()
// Convert reverse DNS format to IP address
ip := reverseDNSToIP(domain)
if ip == nil {
return false
}
_, ok := s.ptrRecords[ip.String()]
return ok
}
// Clear removes all records from the store // Clear removes all records from the store
func (s *DNSRecordStore) Clear() { func (s *DNSRecordStore) Clear() {
s.mu.Lock() s.mu.Lock()
@@ -244,6 +343,7 @@ func (s *DNSRecordStore) Clear() {
s.aaaaRecords = make(map[string][]net.IP) s.aaaaRecords = make(map[string][]net.IP)
s.aWildcards = make(map[string][]net.IP) s.aWildcards = make(map[string][]net.IP)
s.aaaaWildcards = make(map[string][]net.IP) s.aaaaWildcards = make(map[string][]net.IP)
s.ptrRecords = make(map[string]string)
} }
// removeIP is a helper function to remove a specific IP from a slice // removeIP is a helper function to remove a specific IP from a slice
@@ -323,3 +423,75 @@ func matchWildcardInternal(pattern, domain string, pi, di int) bool {
return matchWildcardInternal(pattern, domain, pi+1, di+1) return matchWildcardInternal(pattern, domain, pi+1, di+1)
} }
// reverseDNSToIP converts a reverse DNS query name to an IP address
// Supports both IPv4 (in-addr.arpa) and IPv6 (ip6.arpa) formats
func reverseDNSToIP(domain string) net.IP {
// Normalize to lowercase and ensure FQDN
domain = strings.ToLower(dns.Fqdn(domain))
// Check for IPv4 reverse DNS (in-addr.arpa)
if strings.HasSuffix(domain, ".in-addr.arpa.") {
// Remove the suffix
ipPart := strings.TrimSuffix(domain, ".in-addr.arpa.")
// Split by dots and reverse
parts := strings.Split(ipPart, ".")
if len(parts) != 4 {
return nil
}
// Reverse the octets
reversed := make([]string, 4)
for i := 0; i < 4; i++ {
reversed[i] = parts[3-i]
}
// Parse as IP
return net.ParseIP(strings.Join(reversed, "."))
}
// Check for IPv6 reverse DNS (ip6.arpa)
if strings.HasSuffix(domain, ".ip6.arpa.") {
// Remove the suffix
ipPart := strings.TrimSuffix(domain, ".ip6.arpa.")
// Split by dots and reverse
parts := strings.Split(ipPart, ".")
if len(parts) != 32 {
return nil
}
// Reverse the nibbles and group into 16-bit hex values
reversed := make([]string, 32)
for i := 0; i < 32; i++ {
reversed[i] = parts[31-i]
}
// Join into IPv6 format (groups of 4 nibbles separated by colons)
var ipv6Parts []string
for i := 0; i < 32; i += 4 {
ipv6Parts = append(ipv6Parts, reversed[i]+reversed[i+1]+reversed[i+2]+reversed[i+3])
}
// Parse as IP
return net.ParseIP(strings.Join(ipv6Parts, ":"))
}
return nil
}
// IPToReverseDNS converts an IP address to reverse DNS format
// Returns the domain name for PTR queries (e.g., "1.0.0.127.in-addr.arpa.")
func IPToReverseDNS(ip net.IP) string {
if ip4 := ip.To4(); ip4 != nil {
// IPv4: reverse octets and append .in-addr.arpa.
return dns.Fqdn(fmt.Sprintf("%d.%d.%d.%d.in-addr.arpa",
ip4[3], ip4[2], ip4[1], ip4[0]))
}
if ip6 := ip.To16(); ip6 != nil && ip.To4() == nil {
// IPv6: expand to 32 nibbles, reverse, and append .ip6.arpa.
var nibbles []string
for i := 15; i >= 0; i-- {
nibbles = append(nibbles, fmt.Sprintf("%x", ip6[i]&0x0f))
nibbles = append(nibbles, fmt.Sprintf("%x", ip6[i]>>4))
}
return dns.Fqdn(strings.Join(nibbles, ".") + ".ip6.arpa")
}
return ""
}

View File

@@ -413,3 +413,452 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
t.Error("Expected HasRecord to return true for mixed case wildcard match") t.Error("Expected HasRecord to return true for mixed case wildcard match")
} }
} }
func TestPTRRecordIPv4(t *testing.T) {
store := NewDNSRecordStore()
// Add PTR record for IPv4
ip := net.ParseIP("192.168.1.1")
domain := "host.example.com."
err := store.AddPTRRecord(ip, domain)
if err != nil {
t.Fatalf("Failed to add PTR record: %v", err)
}
// Test reverse DNS lookup
reverseDomain := "1.1.168.192.in-addr.arpa."
result, ok := store.GetPTRRecord(reverseDomain)
if !ok {
t.Error("Expected PTR record to be found")
}
if result != domain {
t.Errorf("Expected domain %q, got %q", domain, result)
}
// Test HasPTRRecord
if !store.HasPTRRecord(reverseDomain) {
t.Error("Expected HasPTRRecord to return true")
}
// Test non-existent PTR record
_, ok = store.GetPTRRecord("2.1.168.192.in-addr.arpa.")
if ok {
t.Error("Expected PTR record not to be found for different IP")
}
}
func TestPTRRecordIPv6(t *testing.T) {
store := NewDNSRecordStore()
// Add PTR record for IPv6
ip := net.ParseIP("2001:db8::1")
domain := "ipv6host.example.com."
err := store.AddPTRRecord(ip, domain)
if err != nil {
t.Fatalf("Failed to add PTR record: %v", err)
}
// Test reverse DNS lookup
// 2001:db8::1 = 2001:0db8:0000:0000:0000:0000:0000:0001
// Reverse: 1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.
reverseDomain := "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa."
result, ok := store.GetPTRRecord(reverseDomain)
if !ok {
t.Error("Expected IPv6 PTR record to be found")
}
if result != domain {
t.Errorf("Expected domain %q, got %q", domain, result)
}
// Test HasPTRRecord
if !store.HasPTRRecord(reverseDomain) {
t.Error("Expected HasPTRRecord to return true for IPv6")
}
}
func TestRemovePTRRecord(t *testing.T) {
store := NewDNSRecordStore()
// Add PTR record
ip := net.ParseIP("10.0.0.1")
domain := "test.example.com."
err := store.AddPTRRecord(ip, domain)
if err != nil {
t.Fatalf("Failed to add PTR record: %v", err)
}
// Verify it exists
reverseDomain := "1.0.0.10.in-addr.arpa."
_, ok := store.GetPTRRecord(reverseDomain)
if !ok {
t.Error("Expected PTR record to exist before removal")
}
// Remove PTR record
store.RemovePTRRecord(ip)
// Verify it's gone
_, ok = store.GetPTRRecord(reverseDomain)
if ok {
t.Error("Expected PTR record to be removed")
}
// Test HasPTRRecord after removal
if store.HasPTRRecord(reverseDomain) {
t.Error("Expected HasPTRRecord to return false after removal")
}
}
func TestIPToReverseDNS(t *testing.T) {
tests := []struct {
name string
ip string
expected string
}{
{
name: "IPv4 simple",
ip: "192.168.1.1",
expected: "1.1.168.192.in-addr.arpa.",
},
{
name: "IPv4 localhost",
ip: "127.0.0.1",
expected: "1.0.0.127.in-addr.arpa.",
},
{
name: "IPv4 with zeros",
ip: "10.0.0.1",
expected: "1.0.0.10.in-addr.arpa.",
},
{
name: "IPv6 simple",
ip: "2001:db8::1",
expected: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
},
{
name: "IPv6 localhost",
ip: "::1",
expected: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ip := net.ParseIP(tt.ip)
if ip == nil {
t.Fatalf("Failed to parse IP: %s", tt.ip)
}
result := IPToReverseDNS(ip)
if result != tt.expected {
t.Errorf("IPToReverseDNS(%s) = %q, want %q", tt.ip, result, tt.expected)
}
})
}
}
func TestReverseDNSToIP(t *testing.T) {
tests := []struct {
name string
reverseDNS string
expectedIP string
shouldMatch bool
}{
{
name: "IPv4 simple",
reverseDNS: "1.1.168.192.in-addr.arpa.",
expectedIP: "192.168.1.1",
shouldMatch: true,
},
{
name: "IPv4 localhost",
reverseDNS: "1.0.0.127.in-addr.arpa.",
expectedIP: "127.0.0.1",
shouldMatch: true,
},
{
name: "IPv6 simple",
reverseDNS: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
expectedIP: "2001:db8::1",
shouldMatch: true,
},
{
name: "Invalid IPv4 format",
reverseDNS: "1.1.168.in-addr.arpa.",
expectedIP: "",
shouldMatch: false,
},
{
name: "Invalid IPv6 format",
reverseDNS: "1.0.0.0.ip6.arpa.",
expectedIP: "",
shouldMatch: false,
},
{
name: "Not a reverse DNS domain",
reverseDNS: "example.com.",
expectedIP: "",
shouldMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := reverseDNSToIP(tt.reverseDNS)
if tt.shouldMatch {
if result == nil {
t.Errorf("reverseDNSToIP(%q) returned nil, expected IP", tt.reverseDNS)
return
}
expectedIP := net.ParseIP(tt.expectedIP)
if !result.Equal(expectedIP) {
t.Errorf("reverseDNSToIP(%q) = %v, want %v", tt.reverseDNS, result, expectedIP)
}
} else {
if result != nil {
t.Errorf("reverseDNSToIP(%q) = %v, expected nil", tt.reverseDNS, result)
}
}
})
}
}
func TestPTRRecordCaseInsensitive(t *testing.T) {
store := NewDNSRecordStore()
// Add PTR record with mixed case domain
ip := net.ParseIP("192.168.1.1")
domain := "MyHost.Example.Com"
err := store.AddPTRRecord(ip, domain)
if err != nil {
t.Fatalf("Failed to add PTR record: %v", err)
}
// Test lookup with different cases in reverse DNS format
reverseDomain := "1.1.168.192.in-addr.arpa."
result, ok := store.GetPTRRecord(reverseDomain)
if !ok {
t.Error("Expected PTR record to be found")
}
// Domain should be normalized to lowercase
if result != "myhost.example.com." {
t.Errorf("Expected normalized domain %q, got %q", "myhost.example.com.", result)
}
// Test with uppercase reverse DNS
reverseDomainUpper := "1.1.168.192.IN-ADDR.ARPA."
result, ok = store.GetPTRRecord(reverseDomainUpper)
if !ok {
t.Error("Expected PTR record to be found with uppercase reverse DNS")
}
if result != "myhost.example.com." {
t.Errorf("Expected normalized domain %q, got %q", "myhost.example.com.", result)
}
}
func TestClearPTRRecords(t *testing.T) {
store := NewDNSRecordStore()
// Add some PTR records
ip1 := net.ParseIP("192.168.1.1")
ip2 := net.ParseIP("192.168.1.2")
store.AddPTRRecord(ip1, "host1.example.com.")
store.AddPTRRecord(ip2, "host2.example.com.")
// Add some A records too
store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1"))
// Verify PTR records exist
if !store.HasPTRRecord("1.1.168.192.in-addr.arpa.") {
t.Error("Expected PTR record to exist before clear")
}
// Clear all records
store.Clear()
// Verify PTR records are gone
if store.HasPTRRecord("1.1.168.192.in-addr.arpa.") {
t.Error("Expected PTR record to be cleared")
}
if store.HasPTRRecord("2.1.168.192.in-addr.arpa.") {
t.Error("Expected PTR record to be cleared")
}
// Verify A records are also gone
if store.HasRecord("test.example.com.", RecordTypeA) {
t.Error("Expected A record to be cleared")
}
}
func TestAutomaticPTRRecordOnAdd(t *testing.T) {
store := NewDNSRecordStore()
// Add an A record - should automatically add PTR record
domain := "host.example.com."
ip := net.ParseIP("192.168.1.100")
err := store.AddRecord(domain, ip)
if err != nil {
t.Fatalf("Failed to add A record: %v", err)
}
// Verify PTR record was automatically created
reverseDomain := "100.1.168.192.in-addr.arpa."
result, ok := store.GetPTRRecord(reverseDomain)
if !ok {
t.Error("Expected PTR record to be automatically created")
}
if result != domain {
t.Errorf("Expected PTR to point to %q, got %q", domain, result)
}
// Add AAAA record - should also automatically add PTR record
domain6 := "ipv6host.example.com."
ip6 := net.ParseIP("2001:db8::1")
err = store.AddRecord(domain6, ip6)
if err != nil {
t.Fatalf("Failed to add AAAA record: %v", err)
}
// Verify IPv6 PTR record was automatically created
reverseDomain6 := "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa."
result6, ok := store.GetPTRRecord(reverseDomain6)
if !ok {
t.Error("Expected IPv6 PTR record to be automatically created")
}
if result6 != domain6 {
t.Errorf("Expected PTR to point to %q, got %q", domain6, result6)
}
}
func TestAutomaticPTRRecordOnRemove(t *testing.T) {
store := NewDNSRecordStore()
// Add an A record (with automatic PTR)
domain := "host.example.com."
ip := net.ParseIP("192.168.1.100")
store.AddRecord(domain, ip)
// Verify PTR exists
reverseDomain := "100.1.168.192.in-addr.arpa."
if !store.HasPTRRecord(reverseDomain) {
t.Error("Expected PTR record to exist after adding A record")
}
// Remove the A record
store.RemoveRecord(domain, ip)
// Verify PTR was automatically removed
if store.HasPTRRecord(reverseDomain) {
t.Error("Expected PTR record to be automatically removed")
}
// Verify A record is also gone
ips := store.GetRecords(domain, RecordTypeA)
if len(ips) != 0 {
t.Errorf("Expected A record to be removed, got %d records", len(ips))
}
}
func TestAutomaticPTRRecordOnRemoveAll(t *testing.T) {
store := NewDNSRecordStore()
// Add multiple IPs for the same domain
domain := "host.example.com."
ip1 := net.ParseIP("192.168.1.100")
ip2 := net.ParseIP("192.168.1.101")
store.AddRecord(domain, ip1)
store.AddRecord(domain, ip2)
// Verify both PTR records exist
reverseDomain1 := "100.1.168.192.in-addr.arpa."
reverseDomain2 := "101.1.168.192.in-addr.arpa."
if !store.HasPTRRecord(reverseDomain1) {
t.Error("Expected first PTR record to exist")
}
if !store.HasPTRRecord(reverseDomain2) {
t.Error("Expected second PTR record to exist")
}
// Remove all records for the domain
store.RemoveRecord(domain, nil)
// Verify both PTR records were removed
if store.HasPTRRecord(reverseDomain1) {
t.Error("Expected first PTR record to be removed")
}
if store.HasPTRRecord(reverseDomain2) {
t.Error("Expected second PTR record to be removed")
}
}
func TestNoPTRForWildcardRecords(t *testing.T) {
store := NewDNSRecordStore()
// Add wildcard record - should NOT create PTR record
domain := "*.example.com."
ip := net.ParseIP("192.168.1.100")
err := store.AddRecord(domain, ip)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
// Verify no PTR record was created
reverseDomain := "100.1.168.192.in-addr.arpa."
_, ok := store.GetPTRRecord(reverseDomain)
if ok {
t.Error("Expected no PTR record for wildcard domain")
}
// Verify wildcard A record exists
if !store.HasRecord("host.example.com.", RecordTypeA) {
t.Error("Expected wildcard A record to exist")
}
}
func TestPTRRecordOverwrite(t *testing.T) {
store := NewDNSRecordStore()
// Add first domain with IP
domain1 := "host1.example.com."
ip := net.ParseIP("192.168.1.100")
store.AddRecord(domain1, ip)
// Verify PTR points to first domain
reverseDomain := "100.1.168.192.in-addr.arpa."
result, ok := store.GetPTRRecord(reverseDomain)
if !ok {
t.Fatal("Expected PTR record to exist")
}
if result != domain1 {
t.Errorf("Expected PTR to point to %q, got %q", domain1, result)
}
// Add second domain with same IP - should overwrite PTR
domain2 := "host2.example.com."
store.AddRecord(domain2, ip)
// Verify PTR now points to second domain (last one added)
result, ok = store.GetPTRRecord(reverseDomain)
if !ok {
t.Fatal("Expected PTR record to still exist")
}
if result != domain2 {
t.Errorf("Expected PTR to point to %q (overwritten), got %q", domain2, result)
}
// Remove first domain - PTR should remain pointing to second domain
store.RemoveRecord(domain1, ip)
result, ok = store.GetPTRRecord(reverseDomain)
if !ok {
t.Error("Expected PTR record to still exist after removing first domain")
}
if result != domain2 {
t.Errorf("Expected PTR to still point to %q, got %q", domain2, result)
}
// Remove second domain - PTR should now be gone
store.RemoveRecord(domain2, ip)
_, ok = store.GetPTRRecord(reverseDomain)
if ok {
t.Error("Expected PTR record to be removed after removing second domain")
}
}

14
go.mod
View File

@@ -1,6 +1,6 @@
module github.com/fosrl/olm module github.com/fosrl/olm
go 1.25 go 1.25.0
require ( require (
github.com/Microsoft/go-winio v0.6.2 github.com/Microsoft/go-winio v0.6.2
@@ -8,7 +8,7 @@ require (
github.com/godbus/dbus/v5 v5.2.2 github.com/godbus/dbus/v5 v5.2.2
github.com/gorilla/websocket v1.5.3 github.com/gorilla/websocket v1.5.3
github.com/miekg/dns v1.1.70 github.com/miekg/dns v1.1.70
golang.org/x/sys v0.40.0 golang.org/x/sys v0.41.0
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c
@@ -20,16 +20,16 @@ require (
github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-cmp v0.7.0 // indirect
github.com/vishvananda/netlink v1.3.1 // indirect github.com/vishvananda/netlink v1.3.1 // indirect
github.com/vishvananda/netns v0.0.5 // indirect github.com/vishvananda/netns v0.0.5 // indirect
golang.org/x/crypto v0.46.0 // indirect golang.org/x/crypto v0.48.0 // indirect
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect
golang.org/x/mod v0.31.0 // indirect golang.org/x/mod v0.32.0 // indirect
golang.org/x/net v0.48.0 // indirect golang.org/x/net v0.51.0 // indirect
golang.org/x/sync v0.19.0 // indirect golang.org/x/sync v0.19.0 // indirect
golang.org/x/time v0.12.0 // indirect golang.org/x/time v0.12.0 // indirect
golang.org/x/tools v0.40.0 // indirect golang.org/x/tools v0.41.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
) )
// To be used ONLY for local development // To be used ONLY for local development
// replace github.com/fosrl/newt => ../newt replace github.com/fosrl/newt => ../newt

22
go.sum
View File

@@ -1,7 +1,5 @@
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/fosrl/newt v1.9.0 h1:66eJMo6fA+YcBTbddxTfNJXNQo1WWKzmn6zPRP5kSDE=
github.com/fosrl/newt v1.9.0/go.mod h1:d1+yYMnKqg4oLqAM9zdbjthjj2FQEVouiACjqU468ck=
github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ= github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ=
github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
@@ -16,24 +14,24 @@ github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4= github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0= golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0= golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A= golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=

View File

@@ -168,6 +168,7 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
SharedBind: o.sharedBind, SharedBind: o.sharedBind,
WSClient: o.websocket, WSClient: o.websocket,
APIServer: o.apiServer, APIServer: o.apiServer,
PublicDNS: o.tunnelConfig.PublicDNS,
}) })
for i := range wgData.Sites { for i := range wgData.Sites {

View File

@@ -31,7 +31,7 @@ type Olm struct {
privateKey wgtypes.Key privateKey wgtypes.Key
logFile *os.File logFile *os.File
registered bool registered bool
tunnelRunning bool tunnelRunning bool
uapiListener net.Listener uapiListener net.Listener
@@ -111,7 +111,7 @@ func (o *Olm) initTunnelInfo(clientID string) error {
logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount()) logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount())
// Create the holepunch manager // Create the holepunch manager
o.holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm", privateKey.PublicKey().String()) o.holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm", privateKey.PublicKey().String(), o.tunnelConfig.PublicDNS)
return nil return nil
} }
@@ -222,7 +222,7 @@ func (o *Olm) registerAPICallbacks() {
tunnelConfig.MTU = 1420 tunnelConfig.MTU = 1420
} }
if req.DNS == "" { if req.DNS == "" {
tunnelConfig.DNS = "9.9.9.9" tunnelConfig.DNS = "8.8.8.8"
} }
// DNSProxyIP has no default - it must be provided if DNS proxy is desired // DNSProxyIP has no default - it must be provided if DNS proxy is desired
// UpstreamDNS defaults to 8.8.8.8 if not provided // UpstreamDNS defaults to 8.8.8.8 if not provided
@@ -292,16 +292,23 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
logger.Info("Tunnel already running") logger.Info("Tunnel already running")
return return
} }
// debug print out the whole config // debug print out the whole config
logger.Debug("Starting tunnel with config: %+v", config) logger.Debug("Starting tunnel with config: %+v", config)
o.tunnelRunning = true // Also set it here in case it is called externally o.tunnelRunning = true // Also set it here in case it is called externally
o.tunnelConfig = config o.tunnelConfig = config
// TODO: we are hardcoding this for now but we should really pull it from the current config of the system
if o.tunnelConfig.DNS != "" {
o.tunnelConfig.PublicDNS = []string{o.tunnelConfig.DNS + ":53"}
} else {
o.tunnelConfig.PublicDNS = []string{"8.8.8.8:53"}
}
// Reset terminated status when tunnel starts // Reset terminated status when tunnel starts
o.apiServer.SetTerminated(false) o.apiServer.SetTerminated(false)
fingerprint := config.InitialFingerprint fingerprint := config.InitialFingerprint
if fingerprint == nil { if fingerprint == nil {
fingerprint = make(map[string]any) fingerprint = make(map[string]any)
@@ -313,7 +320,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
} }
o.SetFingerprint(fingerprint) o.SetFingerprint(fingerprint)
o.SetPostures(postures) o.SetPostures(postures)
// Create a cancellable context for this tunnel process // Create a cancellable context for this tunnel process
tunnelCtx, cancel := context.WithCancel(o.olmCtx) tunnelCtx, cancel := context.WithCancel(o.olmCtx)
@@ -387,7 +394,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
if o.registered { if o.registered {
o.websocket.StartPingMonitor() o.websocket.StartPingMonitor()
logger.Debug("Already registered, skipping registration") logger.Debug("Already registered, skipping registration")
return nil return nil
} }

View File

@@ -170,7 +170,7 @@ func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) {
return return
} }
primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint) primaryRelay, err := util.ResolveDomainUpstream(relayData.RelayEndpoint, o.tunnelConfig.PublicDNS)
if err != nil { if err != nil {
logger.Error("Failed to resolve primary relay endpoint: %v", err) logger.Error("Failed to resolve primary relay endpoint: %v", err)
return return
@@ -203,7 +203,7 @@ func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
return return
} }
primaryRelay, err := util.ResolveDomain(relayData.Endpoint) primaryRelay, err := util.ResolveDomainUpstream(relayData.Endpoint, o.tunnelConfig.PublicDNS)
if err != nil { if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err) logger.Warn("Failed to resolve primary relay endpoint: %v", err)
} }

View File

@@ -61,6 +61,7 @@ type TunnelConfig struct {
MTU int MTU int
DNS string DNS string
UpstreamDNS []string UpstreamDNS []string
PublicDNS []string
InterfaceName string InterfaceName string
// Advanced // Advanced

View File

@@ -32,7 +32,8 @@ type PeerManagerConfig struct {
SharedBind *bind.SharedBind SharedBind *bind.SharedBind
// WSClient is optional - if nil, relay messages won't be sent // WSClient is optional - if nil, relay messages won't be sent
WSClient *websocket.Client WSClient *websocket.Client
APIServer *api.API APIServer *api.API
PublicDNS []string
} }
type PeerManager struct { type PeerManager struct {
@@ -50,7 +51,8 @@ type PeerManager struct {
// key is the CIDR string, value is a set of siteIds that want this IP // key is the CIDR string, value is a set of siteIds that want this IP
allowedIPClaims map[string]map[int]bool allowedIPClaims map[string]map[int]bool
APIServer *api.API APIServer *api.API
publicDNS []string
PersistentKeepalive int PersistentKeepalive int
} }
@@ -65,6 +67,7 @@ func NewPeerManager(config PeerManagerConfig) *PeerManager {
allowedIPOwners: make(map[string]int), allowedIPOwners: make(map[string]int),
allowedIPClaims: make(map[string]map[int]bool), allowedIPClaims: make(map[string]map[int]bool),
APIServer: config.APIServer, APIServer: config.APIServer,
publicDNS: config.PublicDNS,
} }
// Create the peer monitor // Create the peer monitor
@@ -74,6 +77,7 @@ func NewPeerManager(config PeerManagerConfig) *PeerManager {
config.LocalIP, config.LocalIP,
config.SharedBind, config.SharedBind,
config.APIServer, config.APIServer,
config.PublicDNS,
) )
return pm return pm
@@ -129,7 +133,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
wgConfig := siteConfig wgConfig := siteConfig
wgConfig.AllowedIps = ownedIPs wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil { if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive, pm.publicDNS); err != nil {
return err return err
} }
@@ -270,7 +274,7 @@ func (pm *PeerManager) RemovePeer(siteId int) error {
ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId) ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
wgConfig := promotedPeer wgConfig := promotedPeer
wgConfig.AllowedIps = ownedIPs wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil { if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive, pm.publicDNS); err != nil {
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err) logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
} }
} }
@@ -346,7 +350,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
wgConfig := siteConfig wgConfig := siteConfig
wgConfig.AllowedIps = ownedIPs wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil { if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive, pm.publicDNS); err != nil {
return err return err
} }
@@ -356,7 +360,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId) promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
promotedWgConfig := promotedPeer promotedWgConfig := promotedPeer
promotedWgConfig.AllowedIps = promotedOwnedIPs promotedWgConfig.AllowedIps = promotedOwnedIPs
if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil { if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive, pm.publicDNS); err != nil {
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err) logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
} }
} }

View File

@@ -34,6 +34,7 @@ type PeerMonitor struct {
timeout time.Duration timeout time.Duration
maxAttempts int maxAttempts int
wsClient *websocket.Client wsClient *websocket.Client
publicDNS []string
// Netstack fields // Netstack fields
middleDev *middleDevice.MiddleDevice middleDev *middleDevice.MiddleDevice
@@ -82,7 +83,7 @@ type PeerMonitor struct {
} }
// NewPeerMonitor creates a new peer monitor with the given callback // NewPeerMonitor creates a new peer monitor with the given callback
func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API) *PeerMonitor { func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API, publicDNS []string) *PeerMonitor {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
pm := &PeerMonitor{ pm := &PeerMonitor{
monitors: make(map[int]*Client), monitors: make(map[int]*Client),
@@ -91,6 +92,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
wsClient: wsClient, wsClient: wsClient,
middleDev: middleDev, middleDev: middleDev,
localIP: localIP, localIP: localIP,
publicDNS: publicDNS,
activePorts: make(map[uint16]bool), activePorts: make(map[uint16]bool),
nsCtx: ctx, nsCtx: ctx,
nsCancel: cancel, nsCancel: cancel,
@@ -124,7 +126,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
// Initialize holepunch tester if sharedBind is available // Initialize holepunch tester if sharedBind is available
if sharedBind != nil { if sharedBind != nil {
pm.holepunchTester = holepunch.NewHolepunchTester(sharedBind) pm.holepunchTester = holepunch.NewHolepunchTester(sharedBind, publicDNS)
} }
return pm return pm

View File

@@ -11,14 +11,14 @@ import (
) )
// ConfigurePeer sets up or updates a peer within the WireGuard device // ConfigurePeer sets up or updates a peer within the WireGuard device
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool, persistentKeepalive int) error { func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool, persistentKeepalive int, publicDNS []string) error {
var endpoint string var endpoint string
if relay && siteConfig.RelayEndpoint != "" { if relay && siteConfig.RelayEndpoint != "" {
endpoint = formatEndpoint(siteConfig.RelayEndpoint) endpoint = formatEndpoint(siteConfig.RelayEndpoint)
} else { } else {
endpoint = formatEndpoint(siteConfig.Endpoint) endpoint = formatEndpoint(siteConfig.Endpoint)
} }
siteHost, err := util.ResolveDomain(endpoint) siteHost, err := util.ResolveDomainUpstream(endpoint, publicDNS)
if err != nil { if err != nil {
return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err) return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err)
} }

View File

@@ -388,6 +388,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
tokenData := map[string]interface{}{ tokenData := map[string]interface{}{
"olmId": c.config.ID, "olmId": c.config.ID,
"secret": c.config.Secret, "secret": c.config.Secret,
"userToken": c.config.UserToken,
"orgId": c.config.OrgID, "orgId": c.config.OrgID,
} }
jsonData, err := json.Marshal(tokenData) jsonData, err := json.Marshal(tokenData)