mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
Compare commits
13 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5527bff671 | ||
|
|
af973b2440 | ||
|
|
dd9bff9a4b | ||
|
|
1be5e454ba | ||
|
|
4850b1b332 | ||
|
|
1ff74f7173 | ||
|
|
4a25a0d413 | ||
|
|
7fc3c7088e | ||
|
|
1869e70894 | ||
|
|
79783cc3dc | ||
|
|
584298e3bd | ||
|
|
f683afa647 | ||
|
|
ba2631d388 |
@@ -272,9 +272,6 @@ func (s *API) SetConnectionStatus(isConnected bool) {
|
|||||||
|
|
||||||
if isConnected {
|
if isConnected {
|
||||||
s.connectedAt = time.Now()
|
s.connectedAt = time.Now()
|
||||||
} else {
|
|
||||||
// Clear peer statuses when disconnected
|
|
||||||
s.peerStatuses = make(map[int]*PeerStatus)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -89,6 +89,7 @@ func DefaultConfig() *OlmConfig {
|
|||||||
PingInterval: "3s",
|
PingInterval: "3s",
|
||||||
PingTimeout: "5s",
|
PingTimeout: "5s",
|
||||||
DisableHolepunch: false,
|
DisableHolepunch: false,
|
||||||
|
OverrideDNS: true,
|
||||||
TunnelDNS: false,
|
TunnelDNS: false,
|
||||||
// DoNotCreateNewClient: false,
|
// DoNotCreateNewClient: false,
|
||||||
sources: make(map[string]string),
|
sources: make(map[string]string),
|
||||||
@@ -324,9 +325,9 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
|
|||||||
serviceFlags.StringVar(&config.PingTimeout, "ping-timeout", config.PingTimeout, "Timeout for each ping")
|
serviceFlags.StringVar(&config.PingTimeout, "ping-timeout", config.PingTimeout, "Timeout for each ping")
|
||||||
serviceFlags.BoolVar(&config.EnableAPI, "enable-api", config.EnableAPI, "Enable API server for receiving connection requests")
|
serviceFlags.BoolVar(&config.EnableAPI, "enable-api", config.EnableAPI, "Enable API server for receiving connection requests")
|
||||||
serviceFlags.BoolVar(&config.DisableHolepunch, "disable-holepunch", config.DisableHolepunch, "Disable hole punching")
|
serviceFlags.BoolVar(&config.DisableHolepunch, "disable-holepunch", config.DisableHolepunch, "Disable hole punching")
|
||||||
serviceFlags.BoolVar(&config.OverrideDNS, "override-dns", config.OverrideDNS, "Override system DNS settings")
|
serviceFlags.BoolVar(&config.OverrideDNS, "override-dns", config.OverrideDNS, "When enabled, the client uses custom DNS servers to resolve internal resources and aliases. This overrides your system's default DNS settings. Queries that cannot be resolved as a Pangolin resource will be forwarded to your configured Upstream DNS Server. (default false)")
|
||||||
serviceFlags.BoolVar(&config.DisableRelay, "disable-relay", config.DisableRelay, "Disable relay connections")
|
serviceFlags.BoolVar(&config.DisableRelay, "disable-relay", config.DisableRelay, "Disable relay connections")
|
||||||
serviceFlags.BoolVar(&config.TunnelDNS, "tunnel-dns", config.TunnelDNS, "Use tunnel for DNS traffic")
|
serviceFlags.BoolVar(&config.TunnelDNS, "tunnel-dns", config.TunnelDNS, "When enabled, DNS queries are routed through the tunnel for remote resolution. To ensure queries are tunneled correctly, you must define the DNS server as a Pangolin resource and enter its address as an Upstream DNS Server. (default false)")
|
||||||
// serviceFlags.BoolVar(&config.DoNotCreateNewClient, "do-not-create-new-client", config.DoNotCreateNewClient, "Do not create new client")
|
// serviceFlags.BoolVar(&config.DoNotCreateNewClient, "do-not-create-new-client", config.DoNotCreateNewClient, "Do not create new client")
|
||||||
|
|
||||||
version := serviceFlags.Bool("version", false, "Print the version")
|
version := serviceFlags.Bool("version", false, "Print the version")
|
||||||
|
|||||||
@@ -380,7 +380,7 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie
|
|||||||
|
|
||||||
// Check if we have local records for this query
|
// Check if we have local records for this query
|
||||||
var response *dns.Msg
|
var response *dns.Msg
|
||||||
if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA {
|
if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA || question.Qtype == dns.TypePTR {
|
||||||
response = p.checkLocalRecords(msg, question)
|
response = p.checkLocalRecords(msg, question)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -410,6 +410,34 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie
|
|||||||
|
|
||||||
// checkLocalRecords checks if we have local records for the query
|
// checkLocalRecords checks if we have local records for the query
|
||||||
func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns.Msg {
|
func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns.Msg {
|
||||||
|
// Handle PTR queries
|
||||||
|
if question.Qtype == dns.TypePTR {
|
||||||
|
if ptrDomain, ok := p.recordStore.GetPTRRecord(question.Name); ok {
|
||||||
|
logger.Debug("Found local PTR record for %s -> %s", question.Name, ptrDomain)
|
||||||
|
|
||||||
|
// Create response message
|
||||||
|
response := new(dns.Msg)
|
||||||
|
response.SetReply(query)
|
||||||
|
response.Authoritative = true
|
||||||
|
|
||||||
|
// Add PTR answer record
|
||||||
|
rr := &dns.PTR{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: question.Name,
|
||||||
|
Rrtype: dns.TypePTR,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: 300, // 5 minutes
|
||||||
|
},
|
||||||
|
Ptr: ptrDomain,
|
||||||
|
}
|
||||||
|
response.Answer = append(response.Answer, rr)
|
||||||
|
|
||||||
|
return response
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle A and AAAA queries
|
||||||
var recordType RecordType
|
var recordType RecordType
|
||||||
if question.Qtype == dns.TypeA {
|
if question.Qtype == dns.TypeA {
|
||||||
recordType = RecordTypeA
|
recordType = RecordTypeA
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -14,15 +15,17 @@ type RecordType uint16
|
|||||||
const (
|
const (
|
||||||
RecordTypeA RecordType = RecordType(dns.TypeA)
|
RecordTypeA RecordType = RecordType(dns.TypeA)
|
||||||
RecordTypeAAAA RecordType = RecordType(dns.TypeAAAA)
|
RecordTypeAAAA RecordType = RecordType(dns.TypeAAAA)
|
||||||
|
RecordTypePTR RecordType = RecordType(dns.TypePTR)
|
||||||
)
|
)
|
||||||
|
|
||||||
// DNSRecordStore manages local DNS records for A and AAAA queries
|
// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries
|
||||||
type DNSRecordStore struct {
|
type DNSRecordStore struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
aRecords map[string][]net.IP // domain -> list of IPv4 addresses
|
aRecords map[string][]net.IP // domain -> list of IPv4 addresses
|
||||||
aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses
|
aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses
|
||||||
aWildcards map[string][]net.IP // wildcard pattern -> list of IPv4 addresses
|
aWildcards map[string][]net.IP // wildcard pattern -> list of IPv4 addresses
|
||||||
aaaaWildcards map[string][]net.IP // wildcard pattern -> list of IPv6 addresses
|
aaaaWildcards map[string][]net.IP // wildcard pattern -> list of IPv6 addresses
|
||||||
|
ptrRecords map[string]string // IP address string -> domain name
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDNSRecordStore creates a new DNS record store
|
// NewDNSRecordStore creates a new DNS record store
|
||||||
@@ -32,6 +35,7 @@ func NewDNSRecordStore() *DNSRecordStore {
|
|||||||
aaaaRecords: make(map[string][]net.IP),
|
aaaaRecords: make(map[string][]net.IP),
|
||||||
aWildcards: make(map[string][]net.IP),
|
aWildcards: make(map[string][]net.IP),
|
||||||
aaaaWildcards: make(map[string][]net.IP),
|
aaaaWildcards: make(map[string][]net.IP),
|
||||||
|
ptrRecords: make(map[string]string),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -39,6 +43,7 @@ func NewDNSRecordStore() *DNSRecordStore {
|
|||||||
// domain should be in FQDN format (e.g., "example.com.")
|
// domain should be in FQDN format (e.g., "example.com.")
|
||||||
// domain can contain wildcards: * (0+ chars) and ? (exactly 1 char)
|
// domain can contain wildcards: * (0+ chars) and ? (exactly 1 char)
|
||||||
// ip should be a valid IPv4 or IPv6 address
|
// ip should be a valid IPv4 or IPv6 address
|
||||||
|
// Automatically adds a corresponding PTR record for non-wildcard domains
|
||||||
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
|
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
@@ -48,8 +53,8 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
|
|||||||
domain = domain + "."
|
domain = domain + "."
|
||||||
}
|
}
|
||||||
|
|
||||||
// Normalize domain to lowercase
|
// Normalize domain to lowercase FQDN
|
||||||
domain = dns.Fqdn(domain)
|
domain = strings.ToLower(dns.Fqdn(domain))
|
||||||
|
|
||||||
// Check if domain contains wildcards
|
// Check if domain contains wildcards
|
||||||
isWildcard := strings.ContainsAny(domain, "*?")
|
isWildcard := strings.ContainsAny(domain, "*?")
|
||||||
@@ -60,6 +65,8 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
|
|||||||
s.aWildcards[domain] = append(s.aWildcards[domain], ip)
|
s.aWildcards[domain] = append(s.aWildcards[domain], ip)
|
||||||
} else {
|
} else {
|
||||||
s.aRecords[domain] = append(s.aRecords[domain], ip)
|
s.aRecords[domain] = append(s.aRecords[domain], ip)
|
||||||
|
// Automatically add PTR record for non-wildcard domains
|
||||||
|
s.ptrRecords[ip.String()] = domain
|
||||||
}
|
}
|
||||||
} else if ip.To16() != nil {
|
} else if ip.To16() != nil {
|
||||||
// IPv6 address
|
// IPv6 address
|
||||||
@@ -67,6 +74,8 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
|
|||||||
s.aaaaWildcards[domain] = append(s.aaaaWildcards[domain], ip)
|
s.aaaaWildcards[domain] = append(s.aaaaWildcards[domain], ip)
|
||||||
} else {
|
} else {
|
||||||
s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip)
|
s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip)
|
||||||
|
// Automatically add PTR record for non-wildcard domains
|
||||||
|
s.ptrRecords[ip.String()] = domain
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return &net.ParseError{Type: "IP address", Text: ip.String()}
|
return &net.ParseError{Type: "IP address", Text: ip.String()}
|
||||||
@@ -75,8 +84,30 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddPTRRecord adds a PTR record mapping an IP address to a domain name
|
||||||
|
// ip should be a valid IPv4 or IPv6 address
|
||||||
|
// domain should be in FQDN format (e.g., "example.com.")
|
||||||
|
func (s *DNSRecordStore) AddPTRRecord(ip net.IP, domain string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
// Ensure domain ends with a dot (FQDN format)
|
||||||
|
if len(domain) == 0 || domain[len(domain)-1] != '.' {
|
||||||
|
domain = domain + "."
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize domain to lowercase FQDN
|
||||||
|
domain = strings.ToLower(dns.Fqdn(domain))
|
||||||
|
|
||||||
|
// Store PTR record using IP string as key
|
||||||
|
s.ptrRecords[ip.String()] = domain
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// RemoveRecord removes a specific DNS record mapping
|
// RemoveRecord removes a specific DNS record mapping
|
||||||
// If ip is nil, removes all records for the domain (including wildcards)
|
// If ip is nil, removes all records for the domain (including wildcards)
|
||||||
|
// Automatically removes corresponding PTR records for non-wildcard domains
|
||||||
func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
|
func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
@@ -86,8 +117,8 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
|
|||||||
domain = domain + "."
|
domain = domain + "."
|
||||||
}
|
}
|
||||||
|
|
||||||
// Normalize domain to lowercase
|
// Normalize domain to lowercase FQDN
|
||||||
domain = dns.Fqdn(domain)
|
domain = strings.ToLower(dns.Fqdn(domain))
|
||||||
|
|
||||||
// Check if domain contains wildcards
|
// Check if domain contains wildcards
|
||||||
isWildcard := strings.ContainsAny(domain, "*?")
|
isWildcard := strings.ContainsAny(domain, "*?")
|
||||||
@@ -98,6 +129,23 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
|
|||||||
delete(s.aWildcards, domain)
|
delete(s.aWildcards, domain)
|
||||||
delete(s.aaaaWildcards, domain)
|
delete(s.aaaaWildcards, domain)
|
||||||
} else {
|
} else {
|
||||||
|
// For non-wildcard domains, remove PTR records for all IPs
|
||||||
|
if ips, ok := s.aRecords[domain]; ok {
|
||||||
|
for _, ipAddr := range ips {
|
||||||
|
// Only remove PTR if it points to this domain
|
||||||
|
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
|
||||||
|
delete(s.ptrRecords, ipAddr.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ips, ok := s.aaaaRecords[domain]; ok {
|
||||||
|
for _, ipAddr := range ips {
|
||||||
|
// Only remove PTR if it points to this domain
|
||||||
|
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
|
||||||
|
delete(s.ptrRecords, ipAddr.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
delete(s.aRecords, domain)
|
delete(s.aRecords, domain)
|
||||||
delete(s.aaaaRecords, domain)
|
delete(s.aaaaRecords, domain)
|
||||||
}
|
}
|
||||||
@@ -119,6 +167,10 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
|
|||||||
if len(s.aRecords[domain]) == 0 {
|
if len(s.aRecords[domain]) == 0 {
|
||||||
delete(s.aRecords, domain)
|
delete(s.aRecords, domain)
|
||||||
}
|
}
|
||||||
|
// Automatically remove PTR record if it points to this domain
|
||||||
|
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
|
||||||
|
delete(s.ptrRecords, ip.String())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if ip.To16() != nil {
|
} else if ip.To16() != nil {
|
||||||
@@ -136,11 +188,23 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
|
|||||||
if len(s.aaaaRecords[domain]) == 0 {
|
if len(s.aaaaRecords[domain]) == 0 {
|
||||||
delete(s.aaaaRecords, domain)
|
delete(s.aaaaRecords, domain)
|
||||||
}
|
}
|
||||||
|
// Automatically remove PTR record if it points to this domain
|
||||||
|
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
|
||||||
|
delete(s.ptrRecords, ip.String())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RemovePTRRecord removes a PTR record for an IP address
|
||||||
|
func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
delete(s.ptrRecords, ip.String())
|
||||||
|
}
|
||||||
|
|
||||||
// GetRecords returns all IP addresses for a domain and record type
|
// GetRecords returns all IP addresses for a domain and record type
|
||||||
// First checks for exact matches, then checks wildcard patterns
|
// First checks for exact matches, then checks wildcard patterns
|
||||||
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP {
|
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP {
|
||||||
@@ -148,7 +212,7 @@ func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.
|
|||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
// Normalize domain to lowercase FQDN
|
// Normalize domain to lowercase FQDN
|
||||||
domain = dns.Fqdn(domain)
|
domain = strings.ToLower(dns.Fqdn(domain))
|
||||||
|
|
||||||
var records []net.IP
|
var records []net.IP
|
||||||
switch recordType {
|
switch recordType {
|
||||||
@@ -198,6 +262,26 @@ func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.
|
|||||||
return records
|
return records
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPTRRecord returns the domain name for a PTR record query
|
||||||
|
// domain should be in reverse DNS format (e.g., "1.0.0.127.in-addr.arpa.")
|
||||||
|
func (s *DNSRecordStore) GetPTRRecord(domain string) (string, bool) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
// Convert reverse DNS format to IP address
|
||||||
|
ip := reverseDNSToIP(domain)
|
||||||
|
if ip == nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look up the PTR record
|
||||||
|
if ptrDomain, ok := s.ptrRecords[ip.String()]; ok {
|
||||||
|
return ptrDomain, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
// HasRecord checks if a domain has any records of the specified type
|
// HasRecord checks if a domain has any records of the specified type
|
||||||
// Checks both exact matches and wildcard patterns
|
// Checks both exact matches and wildcard patterns
|
||||||
func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
|
func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
|
||||||
@@ -205,7 +289,7 @@ func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
|
|||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
// Normalize domain to lowercase FQDN
|
// Normalize domain to lowercase FQDN
|
||||||
domain = dns.Fqdn(domain)
|
domain = strings.ToLower(dns.Fqdn(domain))
|
||||||
|
|
||||||
switch recordType {
|
switch recordType {
|
||||||
case RecordTypeA:
|
case RecordTypeA:
|
||||||
@@ -235,6 +319,21 @@ func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HasPTRRecord checks if a PTR record exists for the given reverse DNS domain
|
||||||
|
func (s *DNSRecordStore) HasPTRRecord(domain string) bool {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
// Convert reverse DNS format to IP address
|
||||||
|
ip := reverseDNSToIP(domain)
|
||||||
|
if ip == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ok := s.ptrRecords[ip.String()]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
// Clear removes all records from the store
|
// Clear removes all records from the store
|
||||||
func (s *DNSRecordStore) Clear() {
|
func (s *DNSRecordStore) Clear() {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
@@ -244,6 +343,7 @@ func (s *DNSRecordStore) Clear() {
|
|||||||
s.aaaaRecords = make(map[string][]net.IP)
|
s.aaaaRecords = make(map[string][]net.IP)
|
||||||
s.aWildcards = make(map[string][]net.IP)
|
s.aWildcards = make(map[string][]net.IP)
|
||||||
s.aaaaWildcards = make(map[string][]net.IP)
|
s.aaaaWildcards = make(map[string][]net.IP)
|
||||||
|
s.ptrRecords = make(map[string]string)
|
||||||
}
|
}
|
||||||
|
|
||||||
// removeIP is a helper function to remove a specific IP from a slice
|
// removeIP is a helper function to remove a specific IP from a slice
|
||||||
@@ -323,3 +423,75 @@ func matchWildcardInternal(pattern, domain string, pi, di int) bool {
|
|||||||
|
|
||||||
return matchWildcardInternal(pattern, domain, pi+1, di+1)
|
return matchWildcardInternal(pattern, domain, pi+1, di+1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// reverseDNSToIP converts a reverse DNS query name to an IP address
|
||||||
|
// Supports both IPv4 (in-addr.arpa) and IPv6 (ip6.arpa) formats
|
||||||
|
func reverseDNSToIP(domain string) net.IP {
|
||||||
|
// Normalize to lowercase and ensure FQDN
|
||||||
|
domain = strings.ToLower(dns.Fqdn(domain))
|
||||||
|
|
||||||
|
// Check for IPv4 reverse DNS (in-addr.arpa)
|
||||||
|
if strings.HasSuffix(domain, ".in-addr.arpa.") {
|
||||||
|
// Remove the suffix
|
||||||
|
ipPart := strings.TrimSuffix(domain, ".in-addr.arpa.")
|
||||||
|
// Split by dots and reverse
|
||||||
|
parts := strings.Split(ipPart, ".")
|
||||||
|
if len(parts) != 4 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// Reverse the octets
|
||||||
|
reversed := make([]string, 4)
|
||||||
|
for i := 0; i < 4; i++ {
|
||||||
|
reversed[i] = parts[3-i]
|
||||||
|
}
|
||||||
|
// Parse as IP
|
||||||
|
return net.ParseIP(strings.Join(reversed, "."))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for IPv6 reverse DNS (ip6.arpa)
|
||||||
|
if strings.HasSuffix(domain, ".ip6.arpa.") {
|
||||||
|
// Remove the suffix
|
||||||
|
ipPart := strings.TrimSuffix(domain, ".ip6.arpa.")
|
||||||
|
// Split by dots and reverse
|
||||||
|
parts := strings.Split(ipPart, ".")
|
||||||
|
if len(parts) != 32 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// Reverse the nibbles and group into 16-bit hex values
|
||||||
|
reversed := make([]string, 32)
|
||||||
|
for i := 0; i < 32; i++ {
|
||||||
|
reversed[i] = parts[31-i]
|
||||||
|
}
|
||||||
|
// Join into IPv6 format (groups of 4 nibbles separated by colons)
|
||||||
|
var ipv6Parts []string
|
||||||
|
for i := 0; i < 32; i += 4 {
|
||||||
|
ipv6Parts = append(ipv6Parts, reversed[i]+reversed[i+1]+reversed[i+2]+reversed[i+3])
|
||||||
|
}
|
||||||
|
// Parse as IP
|
||||||
|
return net.ParseIP(strings.Join(ipv6Parts, ":"))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPToReverseDNS converts an IP address to reverse DNS format
|
||||||
|
// Returns the domain name for PTR queries (e.g., "1.0.0.127.in-addr.arpa.")
|
||||||
|
func IPToReverseDNS(ip net.IP) string {
|
||||||
|
if ip4 := ip.To4(); ip4 != nil {
|
||||||
|
// IPv4: reverse octets and append .in-addr.arpa.
|
||||||
|
return dns.Fqdn(fmt.Sprintf("%d.%d.%d.%d.in-addr.arpa",
|
||||||
|
ip4[3], ip4[2], ip4[1], ip4[0]))
|
||||||
|
}
|
||||||
|
|
||||||
|
if ip6 := ip.To16(); ip6 != nil && ip.To4() == nil {
|
||||||
|
// IPv6: expand to 32 nibbles, reverse, and append .ip6.arpa.
|
||||||
|
var nibbles []string
|
||||||
|
for i := 15; i >= 0; i-- {
|
||||||
|
nibbles = append(nibbles, fmt.Sprintf("%x", ip6[i]&0x0f))
|
||||||
|
nibbles = append(nibbles, fmt.Sprintf("%x", ip6[i]>>4))
|
||||||
|
}
|
||||||
|
return dns.Fqdn(strings.Join(nibbles, ".") + ".ip6.arpa")
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
@@ -348,3 +348,517 @@ func TestHasRecordWildcard(t *testing.T) {
|
|||||||
t.Error("Expected HasRecord to return false for base domain")
|
t.Error("Expected HasRecord to return false for base domain")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
|
||||||
|
store := NewDNSRecordStore()
|
||||||
|
|
||||||
|
// Add record with mixed case
|
||||||
|
ip := net.ParseIP("10.0.0.1")
|
||||||
|
err := store.AddRecord("MyHost.AutoCo.Internal", ip)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to add mixed case record: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test lookup with different cases
|
||||||
|
testCases := []string{
|
||||||
|
"myhost.autoco.internal.",
|
||||||
|
"MYHOST.AUTOCO.INTERNAL.",
|
||||||
|
"MyHost.AutoCo.Internal.",
|
||||||
|
"mYhOsT.aUtOcO.iNtErNaL.",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, domain := range testCases {
|
||||||
|
ips := store.GetRecords(domain, RecordTypeA)
|
||||||
|
if len(ips) != 1 {
|
||||||
|
t.Errorf("Expected 1 IP for domain %q, got %d", domain, len(ips))
|
||||||
|
}
|
||||||
|
if len(ips) > 0 && !ips[0].Equal(ip) {
|
||||||
|
t.Errorf("Expected IP %v for domain %q, got %v", ip, domain, ips[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test wildcard with mixed case
|
||||||
|
wildcardIP := net.ParseIP("10.0.0.2")
|
||||||
|
err = store.AddRecord("*.Example.Com", wildcardIP)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to add mixed case wildcard: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
wildcardTestCases := []string{
|
||||||
|
"host.example.com.",
|
||||||
|
"HOST.EXAMPLE.COM.",
|
||||||
|
"Host.Example.Com.",
|
||||||
|
"HoSt.ExAmPlE.CoM.",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, domain := range wildcardTestCases {
|
||||||
|
ips := store.GetRecords(domain, RecordTypeA)
|
||||||
|
if len(ips) != 1 {
|
||||||
|
t.Errorf("Expected 1 IP for wildcard domain %q, got %d", domain, len(ips))
|
||||||
|
}
|
||||||
|
if len(ips) > 0 && !ips[0].Equal(wildcardIP) {
|
||||||
|
t.Errorf("Expected IP %v for wildcard domain %q, got %v", wildcardIP, domain, ips[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test removal with different case
|
||||||
|
store.RemoveRecord("MYHOST.AUTOCO.INTERNAL", nil)
|
||||||
|
ips := store.GetRecords("myhost.autoco.internal.", RecordTypeA)
|
||||||
|
if len(ips) != 0 {
|
||||||
|
t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test HasRecord with different case
|
||||||
|
if !store.HasRecord("HOST.EXAMPLE.COM.", RecordTypeA) {
|
||||||
|
t.Error("Expected HasRecord to return true for mixed case wildcard match")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPTRRecordIPv4(t *testing.T) {
|
||||||
|
store := NewDNSRecordStore()
|
||||||
|
|
||||||
|
// Add PTR record for IPv4
|
||||||
|
ip := net.ParseIP("192.168.1.1")
|
||||||
|
domain := "host.example.com."
|
||||||
|
err := store.AddPTRRecord(ip, domain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to add PTR record: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test reverse DNS lookup
|
||||||
|
reverseDomain := "1.1.168.192.in-addr.arpa."
|
||||||
|
result, ok := store.GetPTRRecord(reverseDomain)
|
||||||
|
if !ok {
|
||||||
|
t.Error("Expected PTR record to be found")
|
||||||
|
}
|
||||||
|
if result != domain {
|
||||||
|
t.Errorf("Expected domain %q, got %q", domain, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test HasPTRRecord
|
||||||
|
if !store.HasPTRRecord(reverseDomain) {
|
||||||
|
t.Error("Expected HasPTRRecord to return true")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test non-existent PTR record
|
||||||
|
_, ok = store.GetPTRRecord("2.1.168.192.in-addr.arpa.")
|
||||||
|
if ok {
|
||||||
|
t.Error("Expected PTR record not to be found for different IP")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPTRRecordIPv6(t *testing.T) {
|
||||||
|
store := NewDNSRecordStore()
|
||||||
|
|
||||||
|
// Add PTR record for IPv6
|
||||||
|
ip := net.ParseIP("2001:db8::1")
|
||||||
|
domain := "ipv6host.example.com."
|
||||||
|
err := store.AddPTRRecord(ip, domain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to add PTR record: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test reverse DNS lookup
|
||||||
|
// 2001:db8::1 = 2001:0db8:0000:0000:0000:0000:0000:0001
|
||||||
|
// Reverse: 1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.
|
||||||
|
reverseDomain := "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa."
|
||||||
|
result, ok := store.GetPTRRecord(reverseDomain)
|
||||||
|
if !ok {
|
||||||
|
t.Error("Expected IPv6 PTR record to be found")
|
||||||
|
}
|
||||||
|
if result != domain {
|
||||||
|
t.Errorf("Expected domain %q, got %q", domain, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test HasPTRRecord
|
||||||
|
if !store.HasPTRRecord(reverseDomain) {
|
||||||
|
t.Error("Expected HasPTRRecord to return true for IPv6")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRemovePTRRecord(t *testing.T) {
|
||||||
|
store := NewDNSRecordStore()
|
||||||
|
|
||||||
|
// Add PTR record
|
||||||
|
ip := net.ParseIP("10.0.0.1")
|
||||||
|
domain := "test.example.com."
|
||||||
|
err := store.AddPTRRecord(ip, domain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to add PTR record: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it exists
|
||||||
|
reverseDomain := "1.0.0.10.in-addr.arpa."
|
||||||
|
_, ok := store.GetPTRRecord(reverseDomain)
|
||||||
|
if !ok {
|
||||||
|
t.Error("Expected PTR record to exist before removal")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove PTR record
|
||||||
|
store.RemovePTRRecord(ip)
|
||||||
|
|
||||||
|
// Verify it's gone
|
||||||
|
_, ok = store.GetPTRRecord(reverseDomain)
|
||||||
|
if ok {
|
||||||
|
t.Error("Expected PTR record to be removed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test HasPTRRecord after removal
|
||||||
|
if store.HasPTRRecord(reverseDomain) {
|
||||||
|
t.Error("Expected HasPTRRecord to return false after removal")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPToReverseDNS(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ip string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "IPv4 simple",
|
||||||
|
ip: "192.168.1.1",
|
||||||
|
expected: "1.1.168.192.in-addr.arpa.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 localhost",
|
||||||
|
ip: "127.0.0.1",
|
||||||
|
expected: "1.0.0.127.in-addr.arpa.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 with zeros",
|
||||||
|
ip: "10.0.0.1",
|
||||||
|
expected: "1.0.0.10.in-addr.arpa.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 simple",
|
||||||
|
ip: "2001:db8::1",
|
||||||
|
expected: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 localhost",
|
||||||
|
ip: "::1",
|
||||||
|
expected: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ip := net.ParseIP(tt.ip)
|
||||||
|
if ip == nil {
|
||||||
|
t.Fatalf("Failed to parse IP: %s", tt.ip)
|
||||||
|
}
|
||||||
|
result := IPToReverseDNS(ip)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("IPToReverseDNS(%s) = %q, want %q", tt.ip, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseDNSToIP(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
reverseDNS string
|
||||||
|
expectedIP string
|
||||||
|
shouldMatch bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "IPv4 simple",
|
||||||
|
reverseDNS: "1.1.168.192.in-addr.arpa.",
|
||||||
|
expectedIP: "192.168.1.1",
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 localhost",
|
||||||
|
reverseDNS: "1.0.0.127.in-addr.arpa.",
|
||||||
|
expectedIP: "127.0.0.1",
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 simple",
|
||||||
|
reverseDNS: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
|
||||||
|
expectedIP: "2001:db8::1",
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid IPv4 format",
|
||||||
|
reverseDNS: "1.1.168.in-addr.arpa.",
|
||||||
|
expectedIP: "",
|
||||||
|
shouldMatch: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid IPv6 format",
|
||||||
|
reverseDNS: "1.0.0.0.ip6.arpa.",
|
||||||
|
expectedIP: "",
|
||||||
|
shouldMatch: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Not a reverse DNS domain",
|
||||||
|
reverseDNS: "example.com.",
|
||||||
|
expectedIP: "",
|
||||||
|
shouldMatch: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := reverseDNSToIP(tt.reverseDNS)
|
||||||
|
if tt.shouldMatch {
|
||||||
|
if result == nil {
|
||||||
|
t.Errorf("reverseDNSToIP(%q) returned nil, expected IP", tt.reverseDNS)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
expectedIP := net.ParseIP(tt.expectedIP)
|
||||||
|
if !result.Equal(expectedIP) {
|
||||||
|
t.Errorf("reverseDNSToIP(%q) = %v, want %v", tt.reverseDNS, result, expectedIP)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if result != nil {
|
||||||
|
t.Errorf("reverseDNSToIP(%q) = %v, expected nil", tt.reverseDNS, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPTRRecordCaseInsensitive(t *testing.T) {
|
||||||
|
store := NewDNSRecordStore()
|
||||||
|
|
||||||
|
// Add PTR record with mixed case domain
|
||||||
|
ip := net.ParseIP("192.168.1.1")
|
||||||
|
domain := "MyHost.Example.Com"
|
||||||
|
err := store.AddPTRRecord(ip, domain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to add PTR record: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test lookup with different cases in reverse DNS format
|
||||||
|
reverseDomain := "1.1.168.192.in-addr.arpa."
|
||||||
|
result, ok := store.GetPTRRecord(reverseDomain)
|
||||||
|
if !ok {
|
||||||
|
t.Error("Expected PTR record to be found")
|
||||||
|
}
|
||||||
|
// Domain should be normalized to lowercase
|
||||||
|
if result != "myhost.example.com." {
|
||||||
|
t.Errorf("Expected normalized domain %q, got %q", "myhost.example.com.", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with uppercase reverse DNS
|
||||||
|
reverseDomainUpper := "1.1.168.192.IN-ADDR.ARPA."
|
||||||
|
result, ok = store.GetPTRRecord(reverseDomainUpper)
|
||||||
|
if !ok {
|
||||||
|
t.Error("Expected PTR record to be found with uppercase reverse DNS")
|
||||||
|
}
|
||||||
|
if result != "myhost.example.com." {
|
||||||
|
t.Errorf("Expected normalized domain %q, got %q", "myhost.example.com.", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClearPTRRecords(t *testing.T) {
|
||||||
|
store := NewDNSRecordStore()
|
||||||
|
|
||||||
|
// Add some PTR records
|
||||||
|
ip1 := net.ParseIP("192.168.1.1")
|
||||||
|
ip2 := net.ParseIP("192.168.1.2")
|
||||||
|
store.AddPTRRecord(ip1, "host1.example.com.")
|
||||||
|
store.AddPTRRecord(ip2, "host2.example.com.")
|
||||||
|
|
||||||
|
// Add some A records too
|
||||||
|
store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1"))
|
||||||
|
|
||||||
|
// Verify PTR records exist
|
||||||
|
if !store.HasPTRRecord("1.1.168.192.in-addr.arpa.") {
|
||||||
|
t.Error("Expected PTR record to exist before clear")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear all records
|
||||||
|
store.Clear()
|
||||||
|
|
||||||
|
// Verify PTR records are gone
|
||||||
|
if store.HasPTRRecord("1.1.168.192.in-addr.arpa.") {
|
||||||
|
t.Error("Expected PTR record to be cleared")
|
||||||
|
}
|
||||||
|
if store.HasPTRRecord("2.1.168.192.in-addr.arpa.") {
|
||||||
|
t.Error("Expected PTR record to be cleared")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify A records are also gone
|
||||||
|
if store.HasRecord("test.example.com.", RecordTypeA) {
|
||||||
|
t.Error("Expected A record to be cleared")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAutomaticPTRRecordOnAdd(t *testing.T) {
|
||||||
|
store := NewDNSRecordStore()
|
||||||
|
|
||||||
|
// Add an A record - should automatically add PTR record
|
||||||
|
domain := "host.example.com."
|
||||||
|
ip := net.ParseIP("192.168.1.100")
|
||||||
|
err := store.AddRecord(domain, ip)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to add A record: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify PTR record was automatically created
|
||||||
|
reverseDomain := "100.1.168.192.in-addr.arpa."
|
||||||
|
result, ok := store.GetPTRRecord(reverseDomain)
|
||||||
|
if !ok {
|
||||||
|
t.Error("Expected PTR record to be automatically created")
|
||||||
|
}
|
||||||
|
if result != domain {
|
||||||
|
t.Errorf("Expected PTR to point to %q, got %q", domain, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add AAAA record - should also automatically add PTR record
|
||||||
|
domain6 := "ipv6host.example.com."
|
||||||
|
ip6 := net.ParseIP("2001:db8::1")
|
||||||
|
err = store.AddRecord(domain6, ip6)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to add AAAA record: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify IPv6 PTR record was automatically created
|
||||||
|
reverseDomain6 := "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa."
|
||||||
|
result6, ok := store.GetPTRRecord(reverseDomain6)
|
||||||
|
if !ok {
|
||||||
|
t.Error("Expected IPv6 PTR record to be automatically created")
|
||||||
|
}
|
||||||
|
if result6 != domain6 {
|
||||||
|
t.Errorf("Expected PTR to point to %q, got %q", domain6, result6)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAutomaticPTRRecordOnRemove(t *testing.T) {
|
||||||
|
store := NewDNSRecordStore()
|
||||||
|
|
||||||
|
// Add an A record (with automatic PTR)
|
||||||
|
domain := "host.example.com."
|
||||||
|
ip := net.ParseIP("192.168.1.100")
|
||||||
|
store.AddRecord(domain, ip)
|
||||||
|
|
||||||
|
// Verify PTR exists
|
||||||
|
reverseDomain := "100.1.168.192.in-addr.arpa."
|
||||||
|
if !store.HasPTRRecord(reverseDomain) {
|
||||||
|
t.Error("Expected PTR record to exist after adding A record")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the A record
|
||||||
|
store.RemoveRecord(domain, ip)
|
||||||
|
|
||||||
|
// Verify PTR was automatically removed
|
||||||
|
if store.HasPTRRecord(reverseDomain) {
|
||||||
|
t.Error("Expected PTR record to be automatically removed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify A record is also gone
|
||||||
|
ips := store.GetRecords(domain, RecordTypeA)
|
||||||
|
if len(ips) != 0 {
|
||||||
|
t.Errorf("Expected A record to be removed, got %d records", len(ips))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAutomaticPTRRecordOnRemoveAll(t *testing.T) {
|
||||||
|
store := NewDNSRecordStore()
|
||||||
|
|
||||||
|
// Add multiple IPs for the same domain
|
||||||
|
domain := "host.example.com."
|
||||||
|
ip1 := net.ParseIP("192.168.1.100")
|
||||||
|
ip2 := net.ParseIP("192.168.1.101")
|
||||||
|
store.AddRecord(domain, ip1)
|
||||||
|
store.AddRecord(domain, ip2)
|
||||||
|
|
||||||
|
// Verify both PTR records exist
|
||||||
|
reverseDomain1 := "100.1.168.192.in-addr.arpa."
|
||||||
|
reverseDomain2 := "101.1.168.192.in-addr.arpa."
|
||||||
|
if !store.HasPTRRecord(reverseDomain1) {
|
||||||
|
t.Error("Expected first PTR record to exist")
|
||||||
|
}
|
||||||
|
if !store.HasPTRRecord(reverseDomain2) {
|
||||||
|
t.Error("Expected second PTR record to exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove all records for the domain
|
||||||
|
store.RemoveRecord(domain, nil)
|
||||||
|
|
||||||
|
// Verify both PTR records were removed
|
||||||
|
if store.HasPTRRecord(reverseDomain1) {
|
||||||
|
t.Error("Expected first PTR record to be removed")
|
||||||
|
}
|
||||||
|
if store.HasPTRRecord(reverseDomain2) {
|
||||||
|
t.Error("Expected second PTR record to be removed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNoPTRForWildcardRecords(t *testing.T) {
|
||||||
|
store := NewDNSRecordStore()
|
||||||
|
|
||||||
|
// Add wildcard record - should NOT create PTR record
|
||||||
|
domain := "*.example.com."
|
||||||
|
ip := net.ParseIP("192.168.1.100")
|
||||||
|
err := store.AddRecord(domain, ip)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to add wildcard record: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify no PTR record was created
|
||||||
|
reverseDomain := "100.1.168.192.in-addr.arpa."
|
||||||
|
_, ok := store.GetPTRRecord(reverseDomain)
|
||||||
|
if ok {
|
||||||
|
t.Error("Expected no PTR record for wildcard domain")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify wildcard A record exists
|
||||||
|
if !store.HasRecord("host.example.com.", RecordTypeA) {
|
||||||
|
t.Error("Expected wildcard A record to exist")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPTRRecordOverwrite(t *testing.T) {
|
||||||
|
store := NewDNSRecordStore()
|
||||||
|
|
||||||
|
// Add first domain with IP
|
||||||
|
domain1 := "host1.example.com."
|
||||||
|
ip := net.ParseIP("192.168.1.100")
|
||||||
|
store.AddRecord(domain1, ip)
|
||||||
|
|
||||||
|
// Verify PTR points to first domain
|
||||||
|
reverseDomain := "100.1.168.192.in-addr.arpa."
|
||||||
|
result, ok := store.GetPTRRecord(reverseDomain)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Expected PTR record to exist")
|
||||||
|
}
|
||||||
|
if result != domain1 {
|
||||||
|
t.Errorf("Expected PTR to point to %q, got %q", domain1, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add second domain with same IP - should overwrite PTR
|
||||||
|
domain2 := "host2.example.com."
|
||||||
|
store.AddRecord(domain2, ip)
|
||||||
|
|
||||||
|
// Verify PTR now points to second domain (last one added)
|
||||||
|
result, ok = store.GetPTRRecord(reverseDomain)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Expected PTR record to still exist")
|
||||||
|
}
|
||||||
|
if result != domain2 {
|
||||||
|
t.Errorf("Expected PTR to point to %q (overwritten), got %q", domain2, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove first domain - PTR should remain pointing to second domain
|
||||||
|
store.RemoveRecord(domain1, ip)
|
||||||
|
result, ok = store.GetPTRRecord(reverseDomain)
|
||||||
|
if !ok {
|
||||||
|
t.Error("Expected PTR record to still exist after removing first domain")
|
||||||
|
}
|
||||||
|
if result != domain2 {
|
||||||
|
t.Errorf("Expected PTR to still point to %q, got %q", domain2, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove second domain - PTR should now be gone
|
||||||
|
store.RemoveRecord(domain2, ip)
|
||||||
|
_, ok = store.GetPTRRecord(reverseDomain)
|
||||||
|
if ok {
|
||||||
|
t.Error("Expected PTR record to be removed after removing second domain")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -28,9 +28,15 @@ type OlmErrorData struct {
|
|||||||
func (o *Olm) handleConnect(msg websocket.WSMessage) {
|
func (o *Olm) handleConnect(msg websocket.WSMessage) {
|
||||||
logger.Debug("Received message: %v", msg.Data)
|
logger.Debug("Received message: %v", msg.Data)
|
||||||
|
|
||||||
|
// Check if tunnel is still running
|
||||||
|
if !o.tunnelRunning {
|
||||||
|
logger.Debug("Tunnel stopped, ignoring connect message")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var wgData WgData
|
var wgData WgData
|
||||||
|
|
||||||
if o.connected {
|
if o.registered {
|
||||||
logger.Info("Already connected. Ignoring new connection request.")
|
logger.Info("Already connected. Ignoring new connection request.")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -202,7 +208,7 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
|
|||||||
|
|
||||||
o.apiServer.SetRegistered(true)
|
o.apiServer.SetRegistered(true)
|
||||||
|
|
||||||
o.connected = true
|
o.registered = true
|
||||||
|
|
||||||
// Start ping monitor now that we are registered and connected
|
// Start ping monitor now that we are registered and connected
|
||||||
o.websocket.StartPingMonitor()
|
o.websocket.StartPingMonitor()
|
||||||
@@ -218,6 +224,12 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
|
|||||||
func (o *Olm) handleOlmError(msg websocket.WSMessage) {
|
func (o *Olm) handleOlmError(msg websocket.WSMessage) {
|
||||||
logger.Debug("Received olm error message: %v", msg.Data)
|
logger.Debug("Received olm error message: %v", msg.Data)
|
||||||
|
|
||||||
|
// Check if tunnel is still running
|
||||||
|
if !o.tunnelRunning {
|
||||||
|
logger.Debug("Tunnel stopped, ignoring olm error message")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var errorData OlmErrorData
|
var errorData OlmErrorData
|
||||||
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
@@ -245,6 +257,12 @@ func (o *Olm) handleOlmError(msg websocket.WSMessage) {
|
|||||||
func (o *Olm) handleTerminate(msg websocket.WSMessage) {
|
func (o *Olm) handleTerminate(msg websocket.WSMessage) {
|
||||||
logger.Info("Received terminate message")
|
logger.Info("Received terminate message")
|
||||||
|
|
||||||
|
// Check if tunnel is still running
|
||||||
|
if !o.tunnelRunning {
|
||||||
|
logger.Debug("Tunnel stopped, ignoring terminate message")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var errorData OlmErrorData
|
var errorData OlmErrorData
|
||||||
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
@@ -255,6 +273,12 @@ func (o *Olm) handleTerminate(msg websocket.WSMessage) {
|
|||||||
logger.Error("Error unmarshaling terminate error data: %v", err)
|
logger.Error("Error unmarshaling terminate error data: %v", err)
|
||||||
} else {
|
} else {
|
||||||
logger.Info("Terminate reason (code: %s): %s", errorData.Code, errorData.Message)
|
logger.Info("Terminate reason (code: %s): %s", errorData.Code, errorData.Message)
|
||||||
|
|
||||||
|
if errorData.Code == "TERMINATED_INACTIVITY" {
|
||||||
|
logger.Info("Ignoring...")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Set the olm error in the API server so it can be exposed via status
|
// Set the olm error in the API server so it can be exposed via status
|
||||||
o.apiServer.SetOlmError(errorData.Code, errorData.Message)
|
o.apiServer.SetOlmError(errorData.Code, errorData.Message)
|
||||||
}
|
}
|
||||||
|
|||||||
20
olm/data.go
20
olm/data.go
@@ -13,6 +13,12 @@ import (
|
|||||||
func (o *Olm) handleWgPeerAddData(msg websocket.WSMessage) {
|
func (o *Olm) handleWgPeerAddData(msg websocket.WSMessage) {
|
||||||
logger.Debug("Received add-remote-subnets-aliases message: %v", msg.Data)
|
logger.Debug("Received add-remote-subnets-aliases message: %v", msg.Data)
|
||||||
|
|
||||||
|
// Check if tunnel is still running
|
||||||
|
if !o.tunnelRunning {
|
||||||
|
logger.Debug("Tunnel stopped, ignoring add-remote-subnets-aliases message")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error marshaling data: %v", err)
|
logger.Error("Error marshaling data: %v", err)
|
||||||
@@ -48,6 +54,12 @@ func (o *Olm) handleWgPeerAddData(msg websocket.WSMessage) {
|
|||||||
func (o *Olm) handleWgPeerRemoveData(msg websocket.WSMessage) {
|
func (o *Olm) handleWgPeerRemoveData(msg websocket.WSMessage) {
|
||||||
logger.Debug("Received remove-remote-subnets-aliases message: %v", msg.Data)
|
logger.Debug("Received remove-remote-subnets-aliases message: %v", msg.Data)
|
||||||
|
|
||||||
|
// Check if tunnel is still running
|
||||||
|
if !o.tunnelRunning {
|
||||||
|
logger.Debug("Tunnel stopped, ignoring remove-remote-subnets-aliases message")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error marshaling data: %v", err)
|
logger.Error("Error marshaling data: %v", err)
|
||||||
@@ -83,6 +95,12 @@ func (o *Olm) handleWgPeerRemoveData(msg websocket.WSMessage) {
|
|||||||
func (o *Olm) handleWgPeerUpdateData(msg websocket.WSMessage) {
|
func (o *Olm) handleWgPeerUpdateData(msg websocket.WSMessage) {
|
||||||
logger.Debug("Received update-remote-subnets-aliases message: %v", msg.Data)
|
logger.Debug("Received update-remote-subnets-aliases message: %v", msg.Data)
|
||||||
|
|
||||||
|
// Check if tunnel is still running
|
||||||
|
if !o.tunnelRunning {
|
||||||
|
logger.Debug("Tunnel stopped, ignoring update-remote-subnets-aliases message")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error marshaling data: %v", err)
|
logger.Error("Error marshaling data: %v", err)
|
||||||
@@ -139,7 +157,7 @@ func (o *Olm) handleWgPeerUpdateData(msg websocket.WSMessage) {
|
|||||||
func (o *Olm) handleSync(msg websocket.WSMessage) {
|
func (o *Olm) handleSync(msg websocket.WSMessage) {
|
||||||
logger.Debug("Received sync message: %v", msg.Data)
|
logger.Debug("Received sync message: %v", msg.Data)
|
||||||
|
|
||||||
if !o.connected {
|
if !o.registered {
|
||||||
logger.Warn("Not connected, ignoring sync request")
|
logger.Warn("Not connected, ignoring sync request")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
83
olm/olm.go
83
olm/olm.go
@@ -31,7 +31,7 @@ type Olm struct {
|
|||||||
privateKey wgtypes.Key
|
privateKey wgtypes.Key
|
||||||
logFile *os.File
|
logFile *os.File
|
||||||
|
|
||||||
connected bool
|
registered bool
|
||||||
tunnelRunning bool
|
tunnelRunning bool
|
||||||
|
|
||||||
uapiListener net.Listener
|
uapiListener net.Listener
|
||||||
@@ -66,6 +66,9 @@ type Olm struct {
|
|||||||
updateRegister func(newData any)
|
updateRegister func(newData any)
|
||||||
|
|
||||||
stopPeerSend func()
|
stopPeerSend func()
|
||||||
|
|
||||||
|
// WaitGroup to track tunnel lifecycle
|
||||||
|
tunnelWg sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
// initTunnelInfo creates the shared UDP socket and holepunch manager.
|
// initTunnelInfo creates the shared UDP socket and holepunch manager.
|
||||||
@@ -382,10 +385,16 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
|||||||
|
|
||||||
o.apiServer.SetConnectionStatus(true)
|
o.apiServer.SetConnectionStatus(true)
|
||||||
|
|
||||||
if o.connected {
|
if o.registered {
|
||||||
o.websocket.StartPingMonitor()
|
o.websocket.StartPingMonitor()
|
||||||
|
|
||||||
logger.Debug("Already connected, skipping registration")
|
logger.Debug("Already registered, skipping registration")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if tunnel is still running before starting registration
|
||||||
|
if !o.tunnelRunning {
|
||||||
|
logger.Debug("Tunnel is no longer running, skipping registration")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -394,6 +403,12 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
|||||||
// delay for 500ms to allow for time for the hp to get processed
|
// delay for 500ms to allow for time for the hp to get processed
|
||||||
time.Sleep(500 * time.Millisecond)
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
|
// Check again after sleep in case tunnel was stopped
|
||||||
|
if !o.tunnelRunning {
|
||||||
|
logger.Debug("Tunnel stopped during delay, skipping registration")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if o.stopRegister == nil {
|
if o.stopRegister == nil {
|
||||||
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch)
|
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch)
|
||||||
o.stopRegister, o.updateRegister = o.websocket.SendMessageInterval("olm/wg/register", map[string]any{
|
o.stopRegister, o.updateRegister = o.websocket.SendMessageInterval("olm/wg/register", map[string]any{
|
||||||
@@ -417,6 +432,12 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
o.websocket.OnTokenUpdate(func(token string, exitNodes []websocket.ExitNode) {
|
o.websocket.OnTokenUpdate(func(token string, exitNodes []websocket.ExitNode) {
|
||||||
|
// Check if tunnel is still running and hole punch manager exists
|
||||||
|
if !o.tunnelRunning || o.holePunchManager == nil {
|
||||||
|
logger.Debug("Tunnel stopped or hole punch manager nil, ignoring token update")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
o.holePunchManager.SetToken(token)
|
o.holePunchManager.SetToken(token)
|
||||||
|
|
||||||
logger.Debug("Got exit nodes for hole punching: %v", exitNodes)
|
logger.Debug("Got exit nodes for hole punching: %v", exitNodes)
|
||||||
@@ -447,6 +468,12 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
o.websocket.OnAuthError(func(statusCode int, message string) {
|
o.websocket.OnAuthError(func(statusCode int, message string) {
|
||||||
|
// Check if tunnel is still running
|
||||||
|
if !o.tunnelRunning {
|
||||||
|
logger.Debug("Tunnel stopped, ignoring auth error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
logger.Error("Authentication error (status %d): %s. Terminating tunnel.", statusCode, message)
|
logger.Error("Authentication error (status %d): %s. Terminating tunnel.", statusCode, message)
|
||||||
o.apiServer.SetTerminated(true)
|
o.apiServer.SetTerminated(true)
|
||||||
o.apiServer.SetConnectionStatus(false)
|
o.apiServer.SetConnectionStatus(false)
|
||||||
@@ -466,6 +493,10 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Indicate that tunnel is starting
|
||||||
|
o.tunnelWg.Add(1)
|
||||||
|
defer o.tunnelWg.Done()
|
||||||
|
|
||||||
// Connect to the WebSocket server
|
// Connect to the WebSocket server
|
||||||
if err := o.websocket.Connect(); err != nil {
|
if err := o.websocket.Connect(); err != nil {
|
||||||
logger.Error("Failed to connect to server: %v", err)
|
logger.Error("Failed to connect to server: %v", err)
|
||||||
@@ -479,6 +510,13 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (o *Olm) Close() {
|
func (o *Olm) Close() {
|
||||||
|
// Stop registration first to prevent it from trying to use closed websocket
|
||||||
|
if o.stopRegister != nil {
|
||||||
|
logger.Debug("Stopping registration interval")
|
||||||
|
o.stopRegister()
|
||||||
|
o.stopRegister = nil
|
||||||
|
}
|
||||||
|
|
||||||
// send a disconnect message to the cloud to show disconnected
|
// send a disconnect message to the cloud to show disconnected
|
||||||
if o.websocket != nil {
|
if o.websocket != nil {
|
||||||
o.websocket.SendMessage("olm/disconnecting", map[string]any{})
|
o.websocket.SendMessage("olm/disconnecting", map[string]any{})
|
||||||
@@ -498,11 +536,6 @@ func (o *Olm) Close() {
|
|||||||
o.holePunchManager = nil
|
o.holePunchManager = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if o.stopRegister != nil {
|
|
||||||
o.stopRegister()
|
|
||||||
o.stopRegister = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close() also calls Stop() internally
|
// Close() also calls Stop() internally
|
||||||
if o.peerManager != nil {
|
if o.peerManager != nil {
|
||||||
o.peerManager.Close()
|
o.peerManager.Close()
|
||||||
@@ -533,6 +566,21 @@ func (o *Olm) Close() {
|
|||||||
logger.Debug("Closing MiddleDevice")
|
logger.Debug("Closing MiddleDevice")
|
||||||
_ = o.middleDev.Close()
|
_ = o.middleDev.Close()
|
||||||
o.middleDev = nil
|
o.middleDev = nil
|
||||||
|
} else if o.tdev != nil {
|
||||||
|
// If middleDev was never created but tdev exists, close it directly
|
||||||
|
logger.Debug("Closing TUN device directly (no MiddleDevice)")
|
||||||
|
_ = o.tdev.Close()
|
||||||
|
o.tdev = nil
|
||||||
|
} else if o.tunnelConfig.FileDescriptorTun != 0 {
|
||||||
|
// If we never created a device from the FD, close it explicitly
|
||||||
|
// This can happen if tunnel is stopped during registration before handleConnect
|
||||||
|
logger.Debug("Closing unused TUN file descriptor %d", o.tunnelConfig.FileDescriptorTun)
|
||||||
|
if err := closeFD(o.tunnelConfig.FileDescriptorTun); err != nil {
|
||||||
|
logger.Error("Failed to close TUN file descriptor: %v", err)
|
||||||
|
} else {
|
||||||
|
logger.Info("Closed unused TUN file descriptor")
|
||||||
|
}
|
||||||
|
o.tunnelConfig.FileDescriptorTun = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now close WireGuard device - its TUN reader should have exited by now
|
// Now close WireGuard device - its TUN reader should have exited by now
|
||||||
@@ -565,20 +613,24 @@ func (o *Olm) StopTunnel() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reset the running state BEFORE cleanup to prevent callbacks from accessing nil pointers
|
||||||
|
o.registered = false
|
||||||
|
o.tunnelRunning = false
|
||||||
|
|
||||||
// Cancel the tunnel context if it exists
|
// Cancel the tunnel context if it exists
|
||||||
if o.tunnelCancel != nil {
|
if o.tunnelCancel != nil {
|
||||||
|
logger.Debug("Cancelling tunnel context")
|
||||||
o.tunnelCancel()
|
o.tunnelCancel()
|
||||||
// Give it a moment to clean up
|
|
||||||
time.Sleep(200 * time.Millisecond)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Wait for the tunnel goroutine to complete
|
||||||
|
logger.Debug("Waiting for tunnel goroutine to finish")
|
||||||
|
o.tunnelWg.Wait()
|
||||||
|
logger.Debug("Tunnel goroutine finished")
|
||||||
|
|
||||||
// Close() will handle sending disconnect message and closing websocket
|
// Close() will handle sending disconnect message and closing websocket
|
||||||
o.Close()
|
o.Close()
|
||||||
|
|
||||||
// Reset the connected state
|
|
||||||
o.connected = false
|
|
||||||
o.tunnelRunning = false
|
|
||||||
|
|
||||||
// Update API server status
|
// Update API server status
|
||||||
o.apiServer.SetConnectionStatus(false)
|
o.apiServer.SetConnectionStatus(false)
|
||||||
o.apiServer.SetRegistered(false)
|
o.apiServer.SetRegistered(false)
|
||||||
@@ -686,9 +738,6 @@ func (o *Olm) SetPowerMode(mode string) error {
|
|||||||
|
|
||||||
logger.Info("Switching to low power mode")
|
logger.Info("Switching to low power mode")
|
||||||
|
|
||||||
// Mark as disconnected so we re-register on reconnect
|
|
||||||
o.connected = false
|
|
||||||
|
|
||||||
// Update API server connection status
|
// Update API server connection status
|
||||||
if o.apiServer != nil {
|
if o.apiServer != nil {
|
||||||
o.apiServer.SetConnectionStatus(false)
|
o.apiServer.SetConnectionStatus(false)
|
||||||
|
|||||||
10
olm/olm_unix.go
Normal file
10
olm/olm_unix.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package olm
|
||||||
|
|
||||||
|
import "syscall"
|
||||||
|
|
||||||
|
// closeFD closes a file descriptor in a platform-specific way
|
||||||
|
func closeFD(fd uint32) error {
|
||||||
|
return syscall.Close(int(fd))
|
||||||
|
}
|
||||||
10
olm/olm_windows.go
Normal file
10
olm/olm_windows.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package olm
|
||||||
|
|
||||||
|
import "syscall"
|
||||||
|
|
||||||
|
// closeFD closes a file descriptor in a platform-specific way
|
||||||
|
func closeFD(fd uint32) error {
|
||||||
|
return syscall.Close(syscall.Handle(fd))
|
||||||
|
}
|
||||||
24
olm/peer.go
24
olm/peer.go
@@ -14,6 +14,12 @@ import (
|
|||||||
func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
|
func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
|
||||||
logger.Debug("Received add-peer message: %v", msg.Data)
|
logger.Debug("Received add-peer message: %v", msg.Data)
|
||||||
|
|
||||||
|
// Check if tunnel is still running
|
||||||
|
if !o.tunnelRunning {
|
||||||
|
logger.Debug("Tunnel stopped, ignoring add-peer message")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if o.stopPeerSend != nil {
|
if o.stopPeerSend != nil {
|
||||||
o.stopPeerSend()
|
o.stopPeerSend()
|
||||||
o.stopPeerSend = nil
|
o.stopPeerSend = nil
|
||||||
@@ -44,6 +50,12 @@ func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
|
|||||||
func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
|
func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
|
||||||
logger.Debug("Received remove-peer message: %v", msg.Data)
|
logger.Debug("Received remove-peer message: %v", msg.Data)
|
||||||
|
|
||||||
|
// Check if tunnel is still running
|
||||||
|
if !o.tunnelRunning {
|
||||||
|
logger.Debug("Tunnel stopped, ignoring remove-peer message")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error marshaling data: %v", err)
|
logger.Error("Error marshaling data: %v", err)
|
||||||
@@ -75,6 +87,12 @@ func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
|
|||||||
func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) {
|
func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) {
|
||||||
logger.Debug("Received update-peer message: %v", msg.Data)
|
logger.Debug("Received update-peer message: %v", msg.Data)
|
||||||
|
|
||||||
|
// Check if tunnel is still running
|
||||||
|
if !o.tunnelRunning {
|
||||||
|
logger.Debug("Tunnel stopped, ignoring update-peer message")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error marshaling data: %v", err)
|
logger.Error("Error marshaling data: %v", err)
|
||||||
@@ -199,6 +217,12 @@ func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
|
|||||||
func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
|
func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
|
||||||
logger.Debug("Received peer-handshake message: %v", msg.Data)
|
logger.Debug("Received peer-handshake message: %v", msg.Data)
|
||||||
|
|
||||||
|
// Check if tunnel is still running
|
||||||
|
if !o.tunnelRunning {
|
||||||
|
logger.Debug("Tunnel stopped, ignoring peer-handshake message")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error marshaling handshake data: %v", err)
|
logger.Error("Error marshaling handshake data: %v", err)
|
||||||
|
|||||||
Reference in New Issue
Block a user