mirror of
https://github.com/fosrl/olm.git
synced 2026-03-07 03:06:44 +00:00
Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f2d0e6a14c | ||
|
|
ae88766d85 | ||
|
|
9ae49e36d5 | ||
|
|
5ca4825800 | ||
|
|
809dbe77de | ||
|
|
c67c2a60a1 | ||
|
|
051c0fdfd8 | ||
|
|
e7507e0837 | ||
|
|
21b66fbb34 | ||
|
|
9c0e37eddb | ||
|
|
5527bff671 | ||
|
|
af973b2440 | ||
|
|
dd9bff9a4b | ||
|
|
1be5e454ba | ||
|
|
4a25a0d413 |
62
api/api.go
62
api/api.go
@@ -78,6 +78,13 @@ type MetadataChangeRequest struct {
|
||||
Postures map[string]any `json:"postures"`
|
||||
}
|
||||
|
||||
// JITConnectionRequest defines the structure for a dynamic Just-In-Time connection request.
|
||||
// Either SiteID or ResourceID must be provided (but not necessarily both).
|
||||
type JITConnectionRequest struct {
|
||||
Site string `json:"site,omitempty"`
|
||||
Resource string `json:"resource,omitempty"`
|
||||
}
|
||||
|
||||
// API represents the HTTP server and its state
|
||||
type API struct {
|
||||
addr string
|
||||
@@ -92,6 +99,7 @@ type API struct {
|
||||
onExit func() error
|
||||
onRebind func() error
|
||||
onPowerMode func(PowerModeRequest) error
|
||||
onJITConnect func(JITConnectionRequest) error
|
||||
|
||||
statusMu sync.RWMutex
|
||||
peerStatuses map[int]*PeerStatus
|
||||
@@ -143,6 +151,7 @@ func (s *API) SetHandlers(
|
||||
onExit func() error,
|
||||
onRebind func() error,
|
||||
onPowerMode func(PowerModeRequest) error,
|
||||
onJITConnect func(JITConnectionRequest) error,
|
||||
) {
|
||||
s.onConnect = onConnect
|
||||
s.onSwitchOrg = onSwitchOrg
|
||||
@@ -151,6 +160,7 @@ func (s *API) SetHandlers(
|
||||
s.onExit = onExit
|
||||
s.onRebind = onRebind
|
||||
s.onPowerMode = onPowerMode
|
||||
s.onJITConnect = onJITConnect
|
||||
}
|
||||
|
||||
// Start starts the HTTP server
|
||||
@@ -169,6 +179,7 @@ func (s *API) Start() error {
|
||||
mux.HandleFunc("/health", s.handleHealth)
|
||||
mux.HandleFunc("/rebind", s.handleRebind)
|
||||
mux.HandleFunc("/power-mode", s.handlePowerMode)
|
||||
mux.HandleFunc("/jit-connect", s.handleJITConnect)
|
||||
|
||||
s.server = &http.Server{
|
||||
Handler: mux,
|
||||
@@ -272,9 +283,6 @@ func (s *API) SetConnectionStatus(isConnected bool) {
|
||||
|
||||
if isConnected {
|
||||
s.connectedAt = time.Now()
|
||||
} else {
|
||||
// Clear peer statuses when disconnected
|
||||
s.peerStatuses = make(map[int]*PeerStatus)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -636,6 +644,54 @@ func (s *API) handleRebind(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
// handleJITConnect handles the /jit-connect endpoint.
|
||||
// It initiates a dynamic Just-In-Time connection to a site identified by either
|
||||
// a site or a resource. Exactly one of the two must be provided.
|
||||
func (s *API) handleJITConnect(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var req JITConnectionRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate that exactly one of site or resource is provided
|
||||
if req.Site == "" && req.Resource == "" {
|
||||
http.Error(w, "Missing required field: either site or resource must be provided", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if req.Site != "" && req.Resource != "" {
|
||||
http.Error(w, "Ambiguous request: provide either site or resource, not both", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Site != "" {
|
||||
logger.Info("Received JIT connection request via API: site=%s", req.Site)
|
||||
} else {
|
||||
logger.Info("Received JIT connection request via API: resource=%s", req.Resource)
|
||||
}
|
||||
|
||||
if s.onJITConnect != nil {
|
||||
if err := s.onJITConnect(req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("JIT connection failed: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
http.Error(w, "JIT connect handler not configured", http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "JIT connection request accepted",
|
||||
})
|
||||
}
|
||||
|
||||
// handlePowerMode handles the /power-mode endpoint
|
||||
// This allows changing the power mode between "normal" and "low"
|
||||
func (s *API) handlePowerMode(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -89,7 +89,7 @@ func DefaultConfig() *OlmConfig {
|
||||
PingInterval: "3s",
|
||||
PingTimeout: "5s",
|
||||
DisableHolepunch: false,
|
||||
OverrideDNS: false,
|
||||
OverrideDNS: true,
|
||||
TunnelDNS: false,
|
||||
// DoNotCreateNewClient: false,
|
||||
sources: make(map[string]string),
|
||||
|
||||
@@ -380,7 +380,7 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie
|
||||
|
||||
// Check if we have local records for this query
|
||||
var response *dns.Msg
|
||||
if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA {
|
||||
if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA || question.Qtype == dns.TypePTR {
|
||||
response = p.checkLocalRecords(msg, question)
|
||||
}
|
||||
|
||||
@@ -410,6 +410,34 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie
|
||||
|
||||
// checkLocalRecords checks if we have local records for the query
|
||||
func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns.Msg {
|
||||
// Handle PTR queries
|
||||
if question.Qtype == dns.TypePTR {
|
||||
if ptrDomain, ok := p.recordStore.GetPTRRecord(question.Name); ok {
|
||||
logger.Debug("Found local PTR record for %s -> %s", question.Name, ptrDomain)
|
||||
|
||||
// Create response message
|
||||
response := new(dns.Msg)
|
||||
response.SetReply(query)
|
||||
response.Authoritative = true
|
||||
|
||||
// Add PTR answer record
|
||||
rr := &dns.PTR{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: question.Name,
|
||||
Rrtype: dns.TypePTR,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300, // 5 minutes
|
||||
},
|
||||
Ptr: ptrDomain,
|
||||
}
|
||||
response.Answer = append(response.Answer, rr)
|
||||
|
||||
return response
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle A and AAAA queries
|
||||
var recordType RecordType
|
||||
if question.Qtype == dns.TypeA {
|
||||
recordType = RecordTypeA
|
||||
@@ -419,19 +447,20 @@ func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns
|
||||
return nil
|
||||
}
|
||||
|
||||
ips := p.recordStore.GetRecords(question.Name, recordType)
|
||||
if len(ips) == 0 {
|
||||
ips, exists := p.recordStore.GetRecords(question.Name, recordType)
|
||||
if !exists {
|
||||
// Domain not found in local records, forward to upstream
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Debug("Found %d local record(s) for %s", len(ips), question.Name)
|
||||
|
||||
// Create response message
|
||||
// Create response message (NODATA if no records, otherwise with answers)
|
||||
response := new(dns.Msg)
|
||||
response.SetReply(query)
|
||||
response.Authoritative = true
|
||||
|
||||
// Add answer records
|
||||
// Add answer records (loop is a no-op if ips is empty)
|
||||
for _, ip := range ips {
|
||||
var rr dns.RR
|
||||
if question.Qtype == dns.TypeA {
|
||||
@@ -702,8 +731,9 @@ func (p *DNSProxy) RemoveDNSRecord(domain string, ip net.IP) {
|
||||
p.recordStore.RemoveRecord(domain, ip)
|
||||
}
|
||||
|
||||
// GetDNSRecords returns all IP addresses for a domain and record type
|
||||
func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) []net.IP {
|
||||
// GetDNSRecords returns all IP addresses for a domain and record type.
|
||||
// The second return value indicates whether the domain exists.
|
||||
func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) ([]net.IP, bool) {
|
||||
return p.recordStore.GetRecords(domain, recordType)
|
||||
}
|
||||
|
||||
|
||||
178
dns/dns_proxy_test.go
Normal file
178
dns/dns_proxy_test.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestCheckLocalRecordsNODATAForAAAA(t *testing.T) {
|
||||
proxy := &DNSProxy{
|
||||
recordStore: NewDNSRecordStore(),
|
||||
}
|
||||
|
||||
// Add an A record for a domain (no AAAA record)
|
||||
ip := net.ParseIP("10.0.0.1")
|
||||
err := proxy.recordStore.AddRecord("myservice.internal", ip)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add A record: %v", err)
|
||||
}
|
||||
|
||||
// Query AAAA for domain with only A record - should return NODATA
|
||||
query := new(dns.Msg)
|
||||
query.SetQuestion("myservice.internal.", dns.TypeAAAA)
|
||||
response := proxy.checkLocalRecords(query, query.Question[0])
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("Expected NODATA response, got nil (would forward to upstream)")
|
||||
}
|
||||
if response.Rcode != dns.RcodeSuccess {
|
||||
t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode)
|
||||
}
|
||||
if len(response.Answer) != 0 {
|
||||
t.Errorf("Expected empty answer section for NODATA, got %d answers", len(response.Answer))
|
||||
}
|
||||
if !response.Authoritative {
|
||||
t.Error("Expected response to be authoritative")
|
||||
}
|
||||
|
||||
// Query A for same domain - should return the record
|
||||
query = new(dns.Msg)
|
||||
query.SetQuestion("myservice.internal.", dns.TypeA)
|
||||
response = proxy.checkLocalRecords(query, query.Question[0])
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("Expected response with A record, got nil")
|
||||
}
|
||||
if len(response.Answer) != 1 {
|
||||
t.Fatalf("Expected 1 answer, got %d", len(response.Answer))
|
||||
}
|
||||
aRecord, ok := response.Answer[0].(*dns.A)
|
||||
if !ok {
|
||||
t.Fatal("Expected A record in answer")
|
||||
}
|
||||
if !aRecord.A.Equal(ip.To4()) {
|
||||
t.Errorf("Expected IP %v, got %v", ip.To4(), aRecord.A)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckLocalRecordsNODATAForA(t *testing.T) {
|
||||
proxy := &DNSProxy{
|
||||
recordStore: NewDNSRecordStore(),
|
||||
}
|
||||
|
||||
// Add an AAAA record for a domain (no A record)
|
||||
ip := net.ParseIP("2001:db8::1")
|
||||
err := proxy.recordStore.AddRecord("ipv6only.internal", ip)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add AAAA record: %v", err)
|
||||
}
|
||||
|
||||
// Query A for domain with only AAAA record - should return NODATA
|
||||
query := new(dns.Msg)
|
||||
query.SetQuestion("ipv6only.internal.", dns.TypeA)
|
||||
response := proxy.checkLocalRecords(query, query.Question[0])
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("Expected NODATA response, got nil")
|
||||
}
|
||||
if response.Rcode != dns.RcodeSuccess {
|
||||
t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode)
|
||||
}
|
||||
if len(response.Answer) != 0 {
|
||||
t.Errorf("Expected empty answer section, got %d answers", len(response.Answer))
|
||||
}
|
||||
if !response.Authoritative {
|
||||
t.Error("Expected response to be authoritative")
|
||||
}
|
||||
|
||||
// Query AAAA for same domain - should return the record
|
||||
query = new(dns.Msg)
|
||||
query.SetQuestion("ipv6only.internal.", dns.TypeAAAA)
|
||||
response = proxy.checkLocalRecords(query, query.Question[0])
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("Expected response with AAAA record, got nil")
|
||||
}
|
||||
if len(response.Answer) != 1 {
|
||||
t.Fatalf("Expected 1 answer, got %d", len(response.Answer))
|
||||
}
|
||||
aaaaRecord, ok := response.Answer[0].(*dns.AAAA)
|
||||
if !ok {
|
||||
t.Fatal("Expected AAAA record in answer")
|
||||
}
|
||||
if !aaaaRecord.AAAA.Equal(ip) {
|
||||
t.Errorf("Expected IP %v, got %v", ip, aaaaRecord.AAAA)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckLocalRecordsNonExistentDomain(t *testing.T) {
|
||||
proxy := &DNSProxy{
|
||||
recordStore: NewDNSRecordStore(),
|
||||
}
|
||||
|
||||
// Add a record so the store isn't empty
|
||||
err := proxy.recordStore.AddRecord("exists.internal", net.ParseIP("10.0.0.1"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add record: %v", err)
|
||||
}
|
||||
|
||||
// Query A for non-existent domain - should return nil (forward to upstream)
|
||||
query := new(dns.Msg)
|
||||
query.SetQuestion("unknown.internal.", dns.TypeA)
|
||||
response := proxy.checkLocalRecords(query, query.Question[0])
|
||||
|
||||
if response != nil {
|
||||
t.Error("Expected nil for non-existent domain, got response")
|
||||
}
|
||||
|
||||
// Query AAAA for non-existent domain - should also return nil
|
||||
query = new(dns.Msg)
|
||||
query.SetQuestion("unknown.internal.", dns.TypeAAAA)
|
||||
response = proxy.checkLocalRecords(query, query.Question[0])
|
||||
|
||||
if response != nil {
|
||||
t.Error("Expected nil for non-existent domain, got response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckLocalRecordsNODATAWildcard(t *testing.T) {
|
||||
proxy := &DNSProxy{
|
||||
recordStore: NewDNSRecordStore(),
|
||||
}
|
||||
|
||||
// Add a wildcard A record (no AAAA)
|
||||
ip := net.ParseIP("10.0.0.1")
|
||||
err := proxy.recordStore.AddRecord("*.wildcard.internal", ip)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add wildcard A record: %v", err)
|
||||
}
|
||||
|
||||
// Query AAAA for wildcard-matched domain - should return NODATA
|
||||
query := new(dns.Msg)
|
||||
query.SetQuestion("host.wildcard.internal.", dns.TypeAAAA)
|
||||
response := proxy.checkLocalRecords(query, query.Question[0])
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("Expected NODATA response for wildcard match, got nil")
|
||||
}
|
||||
if response.Rcode != dns.RcodeSuccess {
|
||||
t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode)
|
||||
}
|
||||
if len(response.Answer) != 0 {
|
||||
t.Errorf("Expected empty answer section, got %d answers", len(response.Answer))
|
||||
}
|
||||
|
||||
// Query A for wildcard-matched domain - should return the record
|
||||
query = new(dns.Msg)
|
||||
query.SetQuestion("host.wildcard.internal.", dns.TypeA)
|
||||
response = proxy.checkLocalRecords(query, query.Question[0])
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("Expected response with A record, got nil")
|
||||
}
|
||||
if len(response.Answer) != 1 {
|
||||
t.Fatalf("Expected 1 answer, got %d", len(response.Answer))
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -14,24 +15,30 @@ type RecordType uint16
|
||||
const (
|
||||
RecordTypeA RecordType = RecordType(dns.TypeA)
|
||||
RecordTypeAAAA RecordType = RecordType(dns.TypeAAAA)
|
||||
RecordTypePTR RecordType = RecordType(dns.TypePTR)
|
||||
)
|
||||
|
||||
// DNSRecordStore manages local DNS records for A and AAAA queries
|
||||
// recordSet holds A and AAAA records for a single domain or wildcard pattern
|
||||
type recordSet struct {
|
||||
A []net.IP
|
||||
AAAA []net.IP
|
||||
}
|
||||
|
||||
// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries.
|
||||
// Exact domains are stored in a map; wildcard patterns are in a separate map.
|
||||
type DNSRecordStore struct {
|
||||
mu sync.RWMutex
|
||||
aRecords map[string][]net.IP // domain -> list of IPv4 addresses
|
||||
aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses
|
||||
aWildcards map[string][]net.IP // wildcard pattern -> list of IPv4 addresses
|
||||
aaaaWildcards map[string][]net.IP // wildcard pattern -> list of IPv6 addresses
|
||||
exact map[string]*recordSet // normalized FQDN -> A/AAAA records
|
||||
wildcards map[string]*recordSet // wildcard pattern -> A/AAAA records
|
||||
ptrRecords map[string]string // IP address string -> domain name
|
||||
}
|
||||
|
||||
// NewDNSRecordStore creates a new DNS record store
|
||||
func NewDNSRecordStore() *DNSRecordStore {
|
||||
return &DNSRecordStore{
|
||||
aRecords: make(map[string][]net.IP),
|
||||
aaaaRecords: make(map[string][]net.IP),
|
||||
aWildcards: make(map[string][]net.IP),
|
||||
aaaaWildcards: make(map[string][]net.IP),
|
||||
exact: make(map[string]*recordSet),
|
||||
wildcards: make(map[string]*recordSet),
|
||||
ptrRecords: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,10 +46,52 @@ func NewDNSRecordStore() *DNSRecordStore {
|
||||
// domain should be in FQDN format (e.g., "example.com.")
|
||||
// domain can contain wildcards: * (0+ chars) and ? (exactly 1 char)
|
||||
// ip should be a valid IPv4 or IPv6 address
|
||||
// Automatically adds a corresponding PTR record for non-wildcard domains
|
||||
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if len(domain) == 0 || domain[len(domain)-1] != '.' {
|
||||
domain = domain + "."
|
||||
}
|
||||
domain = strings.ToLower(dns.Fqdn(domain))
|
||||
isWildcard := strings.ContainsAny(domain, "*?")
|
||||
|
||||
isV4 := ip.To4() != nil
|
||||
if !isV4 && ip.To16() == nil {
|
||||
return &net.ParseError{Type: "IP address", Text: ip.String()}
|
||||
}
|
||||
|
||||
// Choose the appropriate map based on whether this is a wildcard
|
||||
m := s.exact
|
||||
if isWildcard {
|
||||
m = s.wildcards
|
||||
}
|
||||
|
||||
if m[domain] == nil {
|
||||
m[domain] = &recordSet{}
|
||||
}
|
||||
rs := m[domain]
|
||||
if isV4 {
|
||||
rs.A = append(rs.A, ip)
|
||||
} else {
|
||||
rs.AAAA = append(rs.AAAA, ip)
|
||||
}
|
||||
|
||||
// Add PTR record for non-wildcard domains
|
||||
if !isWildcard {
|
||||
s.ptrRecords[ip.String()] = domain
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPTRRecord adds a PTR record mapping an IP address to a domain name
|
||||
// ip should be a valid IPv4 or IPv6 address
|
||||
// domain should be in FQDN format (e.g., "example.com.")
|
||||
func (s *DNSRecordStore) AddPTRRecord(ip net.IP, domain string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Ensure domain ends with a dot (FQDN format)
|
||||
if len(domain) == 0 || domain[len(domain)-1] != '.' {
|
||||
domain = domain + "."
|
||||
@@ -51,151 +100,155 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
|
||||
// Normalize domain to lowercase FQDN
|
||||
domain = strings.ToLower(dns.Fqdn(domain))
|
||||
|
||||
// Check if domain contains wildcards
|
||||
isWildcard := strings.ContainsAny(domain, "*?")
|
||||
|
||||
if ip.To4() != nil {
|
||||
// IPv4 address
|
||||
if isWildcard {
|
||||
s.aWildcards[domain] = append(s.aWildcards[domain], ip)
|
||||
} else {
|
||||
s.aRecords[domain] = append(s.aRecords[domain], ip)
|
||||
}
|
||||
} else if ip.To16() != nil {
|
||||
// IPv6 address
|
||||
if isWildcard {
|
||||
s.aaaaWildcards[domain] = append(s.aaaaWildcards[domain], ip)
|
||||
} else {
|
||||
s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip)
|
||||
}
|
||||
} else {
|
||||
return &net.ParseError{Type: "IP address", Text: ip.String()}
|
||||
}
|
||||
// Store PTR record using IP string as key
|
||||
s.ptrRecords[ip.String()] = domain
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRecord removes a specific DNS record mapping
|
||||
// If ip is nil, removes all records for the domain (including wildcards)
|
||||
// Automatically removes corresponding PTR records for non-wildcard domains
|
||||
func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Ensure domain ends with a dot (FQDN format)
|
||||
if len(domain) == 0 || domain[len(domain)-1] != '.' {
|
||||
domain = domain + "."
|
||||
}
|
||||
|
||||
// Normalize domain to lowercase FQDN
|
||||
domain = strings.ToLower(dns.Fqdn(domain))
|
||||
|
||||
// Check if domain contains wildcards
|
||||
isWildcard := strings.ContainsAny(domain, "*?")
|
||||
|
||||
if ip == nil {
|
||||
// Remove all records for this domain
|
||||
// Choose the appropriate map
|
||||
m := s.exact
|
||||
if isWildcard {
|
||||
delete(s.aWildcards, domain)
|
||||
delete(s.aaaaWildcards, domain)
|
||||
} else {
|
||||
delete(s.aRecords, domain)
|
||||
delete(s.aaaaRecords, domain)
|
||||
m = s.wildcards
|
||||
}
|
||||
|
||||
rs := m[domain]
|
||||
if rs == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if ip == nil {
|
||||
// Remove all records for this domain
|
||||
if !isWildcard {
|
||||
for _, ipAddr := range rs.A {
|
||||
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
|
||||
delete(s.ptrRecords, ipAddr.String())
|
||||
}
|
||||
}
|
||||
for _, ipAddr := range rs.AAAA {
|
||||
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
|
||||
delete(s.ptrRecords, ipAddr.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
delete(m, domain)
|
||||
return
|
||||
}
|
||||
|
||||
// Remove specific IP
|
||||
if ip.To4() != nil {
|
||||
// Remove specific IPv4 address
|
||||
if isWildcard {
|
||||
if ips, ok := s.aWildcards[domain]; ok {
|
||||
s.aWildcards[domain] = removeIP(ips, ip)
|
||||
if len(s.aWildcards[domain]) == 0 {
|
||||
delete(s.aWildcards, domain)
|
||||
rs.A = removeIP(rs.A, ip)
|
||||
if !isWildcard {
|
||||
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
|
||||
delete(s.ptrRecords, ip.String())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if ips, ok := s.aRecords[domain]; ok {
|
||||
s.aRecords[domain] = removeIP(ips, ip)
|
||||
if len(s.aRecords[domain]) == 0 {
|
||||
delete(s.aRecords, domain)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if ip.To16() != nil {
|
||||
// Remove specific IPv6 address
|
||||
if isWildcard {
|
||||
if ips, ok := s.aaaaWildcards[domain]; ok {
|
||||
s.aaaaWildcards[domain] = removeIP(ips, ip)
|
||||
if len(s.aaaaWildcards[domain]) == 0 {
|
||||
delete(s.aaaaWildcards, domain)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if ips, ok := s.aaaaRecords[domain]; ok {
|
||||
s.aaaaRecords[domain] = removeIP(ips, ip)
|
||||
if len(s.aaaaRecords[domain]) == 0 {
|
||||
delete(s.aaaaRecords, domain)
|
||||
rs.AAAA = removeIP(rs.AAAA, ip)
|
||||
if !isWildcard {
|
||||
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
|
||||
delete(s.ptrRecords, ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up empty record sets
|
||||
if len(rs.A) == 0 && len(rs.AAAA) == 0 {
|
||||
delete(m, domain)
|
||||
}
|
||||
}
|
||||
|
||||
// GetRecords returns all IP addresses for a domain and record type
|
||||
// First checks for exact matches, then checks wildcard patterns
|
||||
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP {
|
||||
// RemovePTRRecord removes a PTR record for an IP address
|
||||
func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
delete(s.ptrRecords, ip.String())
|
||||
}
|
||||
|
||||
// GetRecords returns all IP addresses for a domain and record type.
|
||||
// The second return value indicates whether the domain exists at all
|
||||
// (true = domain exists, use NODATA if no records; false = NXDOMAIN).
|
||||
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) ([]net.IP, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// Normalize domain to lowercase FQDN
|
||||
domain = strings.ToLower(dns.Fqdn(domain))
|
||||
|
||||
// Check exact match first
|
||||
if rs, exists := s.exact[domain]; exists {
|
||||
var ips []net.IP
|
||||
if recordType == RecordTypeA {
|
||||
ips = rs.A
|
||||
} else {
|
||||
ips = rs.AAAA
|
||||
}
|
||||
if len(ips) > 0 {
|
||||
out := make([]net.IP, len(ips))
|
||||
copy(out, ips)
|
||||
return out, true
|
||||
}
|
||||
// Domain exists but no records of this type
|
||||
return nil, true
|
||||
}
|
||||
|
||||
// Check wildcard matches
|
||||
var records []net.IP
|
||||
switch recordType {
|
||||
case RecordTypeA:
|
||||
// Check exact match first
|
||||
if ips, ok := s.aRecords[domain]; ok {
|
||||
// Return a copy to prevent external modifications
|
||||
records = make([]net.IP, len(ips))
|
||||
copy(records, ips)
|
||||
return records
|
||||
matched := false
|
||||
for pattern, rs := range s.wildcards {
|
||||
if !matchWildcard(pattern, domain) {
|
||||
continue
|
||||
}
|
||||
// Check wildcard patterns
|
||||
for pattern, ips := range s.aWildcards {
|
||||
if matchWildcard(pattern, domain) {
|
||||
records = append(records, ips...)
|
||||
}
|
||||
}
|
||||
if len(records) > 0 {
|
||||
// Return a copy
|
||||
result := make([]net.IP, len(records))
|
||||
copy(result, records)
|
||||
return result
|
||||
}
|
||||
|
||||
case RecordTypeAAAA:
|
||||
// Check exact match first
|
||||
if ips, ok := s.aaaaRecords[domain]; ok {
|
||||
// Return a copy to prevent external modifications
|
||||
records = make([]net.IP, len(ips))
|
||||
copy(records, ips)
|
||||
return records
|
||||
}
|
||||
// Check wildcard patterns
|
||||
for pattern, ips := range s.aaaaWildcards {
|
||||
if matchWildcard(pattern, domain) {
|
||||
records = append(records, ips...)
|
||||
}
|
||||
}
|
||||
if len(records) > 0 {
|
||||
// Return a copy
|
||||
result := make([]net.IP, len(records))
|
||||
copy(result, records)
|
||||
return result
|
||||
matched = true
|
||||
if recordType == RecordTypeA {
|
||||
records = append(records, rs.A...)
|
||||
} else {
|
||||
records = append(records, rs.AAAA...)
|
||||
}
|
||||
}
|
||||
|
||||
return records
|
||||
if !matched {
|
||||
return nil, false
|
||||
}
|
||||
if len(records) == 0 {
|
||||
return nil, true
|
||||
}
|
||||
out := make([]net.IP, len(records))
|
||||
copy(out, records)
|
||||
return out, true
|
||||
}
|
||||
|
||||
// GetPTRRecord returns the domain name for a PTR record query
|
||||
// domain should be in reverse DNS format (e.g., "1.0.0.127.in-addr.arpa.")
|
||||
func (s *DNSRecordStore) GetPTRRecord(domain string) (string, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// Convert reverse DNS format to IP address
|
||||
ip := reverseDNSToIP(domain)
|
||||
if ip == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// Look up the PTR record
|
||||
if ptrDomain, ok := s.ptrRecords[ip.String()]; ok {
|
||||
return ptrDomain, true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
// HasRecord checks if a domain has any records of the specified type
|
||||
@@ -204,46 +257,56 @@ func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// Normalize domain to lowercase FQDN
|
||||
domain = strings.ToLower(dns.Fqdn(domain))
|
||||
|
||||
switch recordType {
|
||||
case RecordTypeA:
|
||||
// Check exact match
|
||||
if _, ok := s.aRecords[domain]; ok {
|
||||
if rs, exists := s.exact[domain]; exists {
|
||||
if recordType == RecordTypeA && len(rs.A) > 0 {
|
||||
return true
|
||||
}
|
||||
// Check wildcard patterns
|
||||
for pattern := range s.aWildcards {
|
||||
if matchWildcard(pattern, domain) {
|
||||
if recordType == RecordTypeAAAA && len(rs.AAAA) > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
case RecordTypeAAAA:
|
||||
// Check exact match
|
||||
if _, ok := s.aaaaRecords[domain]; ok {
|
||||
return true
|
||||
}
|
||||
// Check wildcard patterns
|
||||
for pattern := range s.aaaaWildcards {
|
||||
if matchWildcard(pattern, domain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check wildcard matches
|
||||
for pattern, rs := range s.wildcards {
|
||||
if !matchWildcard(pattern, domain) {
|
||||
continue
|
||||
}
|
||||
if recordType == RecordTypeA && len(rs.A) > 0 {
|
||||
return true
|
||||
}
|
||||
if recordType == RecordTypeAAAA && len(rs.AAAA) > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasPTRRecord checks if a PTR record exists for the given reverse DNS domain
|
||||
func (s *DNSRecordStore) HasPTRRecord(domain string) bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// Convert reverse DNS format to IP address
|
||||
ip := reverseDNSToIP(domain)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
_, ok := s.ptrRecords[ip.String()]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Clear removes all records from the store
|
||||
func (s *DNSRecordStore) Clear() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.aRecords = make(map[string][]net.IP)
|
||||
s.aaaaRecords = make(map[string][]net.IP)
|
||||
s.aWildcards = make(map[string][]net.IP)
|
||||
s.aaaaWildcards = make(map[string][]net.IP)
|
||||
s.exact = make(map[string]*recordSet)
|
||||
s.wildcards = make(map[string]*recordSet)
|
||||
s.ptrRecords = make(map[string]string)
|
||||
}
|
||||
|
||||
// removeIP is a helper function to remove a specific IP from a slice
|
||||
@@ -323,3 +386,75 @@ func matchWildcardInternal(pattern, domain string, pi, di int) bool {
|
||||
|
||||
return matchWildcardInternal(pattern, domain, pi+1, di+1)
|
||||
}
|
||||
|
||||
// reverseDNSToIP converts a reverse DNS query name to an IP address
|
||||
// Supports both IPv4 (in-addr.arpa) and IPv6 (ip6.arpa) formats
|
||||
func reverseDNSToIP(domain string) net.IP {
|
||||
// Normalize to lowercase and ensure FQDN
|
||||
domain = strings.ToLower(dns.Fqdn(domain))
|
||||
|
||||
// Check for IPv4 reverse DNS (in-addr.arpa)
|
||||
if strings.HasSuffix(domain, ".in-addr.arpa.") {
|
||||
// Remove the suffix
|
||||
ipPart := strings.TrimSuffix(domain, ".in-addr.arpa.")
|
||||
// Split by dots and reverse
|
||||
parts := strings.Split(ipPart, ".")
|
||||
if len(parts) != 4 {
|
||||
return nil
|
||||
}
|
||||
// Reverse the octets
|
||||
reversed := make([]string, 4)
|
||||
for i := 0; i < 4; i++ {
|
||||
reversed[i] = parts[3-i]
|
||||
}
|
||||
// Parse as IP
|
||||
return net.ParseIP(strings.Join(reversed, "."))
|
||||
}
|
||||
|
||||
// Check for IPv6 reverse DNS (ip6.arpa)
|
||||
if strings.HasSuffix(domain, ".ip6.arpa.") {
|
||||
// Remove the suffix
|
||||
ipPart := strings.TrimSuffix(domain, ".ip6.arpa.")
|
||||
// Split by dots and reverse
|
||||
parts := strings.Split(ipPart, ".")
|
||||
if len(parts) != 32 {
|
||||
return nil
|
||||
}
|
||||
// Reverse the nibbles and group into 16-bit hex values
|
||||
reversed := make([]string, 32)
|
||||
for i := 0; i < 32; i++ {
|
||||
reversed[i] = parts[31-i]
|
||||
}
|
||||
// Join into IPv6 format (groups of 4 nibbles separated by colons)
|
||||
var ipv6Parts []string
|
||||
for i := 0; i < 32; i += 4 {
|
||||
ipv6Parts = append(ipv6Parts, reversed[i]+reversed[i+1]+reversed[i+2]+reversed[i+3])
|
||||
}
|
||||
// Parse as IP
|
||||
return net.ParseIP(strings.Join(ipv6Parts, ":"))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IPToReverseDNS converts an IP address to reverse DNS format
|
||||
// Returns the domain name for PTR queries (e.g., "1.0.0.127.in-addr.arpa.")
|
||||
func IPToReverseDNS(ip net.IP) string {
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
// IPv4: reverse octets and append .in-addr.arpa.
|
||||
return dns.Fqdn(fmt.Sprintf("%d.%d.%d.%d.in-addr.arpa",
|
||||
ip4[3], ip4[2], ip4[1], ip4[0]))
|
||||
}
|
||||
|
||||
if ip6 := ip.To16(); ip6 != nil && ip.To4() == nil {
|
||||
// IPv6: expand to 32 nibbles, reverse, and append .ip6.arpa.
|
||||
var nibbles []string
|
||||
for i := 15; i >= 0; i-- {
|
||||
nibbles = append(nibbles, fmt.Sprintf("%x", ip6[i]&0x0f))
|
||||
nibbles = append(nibbles, fmt.Sprintf("%x", ip6[i]>>4))
|
||||
}
|
||||
return dns.Fqdn(strings.Join(nibbles, ".") + ".ip6.arpa")
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -183,25 +183,34 @@ func TestDNSRecordStoreWildcard(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test exact match takes precedence
|
||||
ips := store.GetRecords("exact.autoco.internal.", RecordTypeA)
|
||||
ips, exists := store.GetRecords("exact.autoco.internal.", RecordTypeA)
|
||||
if !exists {
|
||||
t.Error("Expected domain to exist")
|
||||
}
|
||||
if len(ips) != 1 {
|
||||
t.Errorf("Expected 1 IP for exact match, got %d", len(ips))
|
||||
}
|
||||
if !ips[0].Equal(exactIP) {
|
||||
if len(ips) > 0 && !ips[0].Equal(exactIP) {
|
||||
t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0])
|
||||
}
|
||||
|
||||
// Test wildcard match
|
||||
ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
|
||||
ips, exists = store.GetRecords("host.autoco.internal.", RecordTypeA)
|
||||
if !exists {
|
||||
t.Error("Expected wildcard match to exist")
|
||||
}
|
||||
if len(ips) != 1 {
|
||||
t.Errorf("Expected 1 IP for wildcard match, got %d", len(ips))
|
||||
}
|
||||
if !ips[0].Equal(wildcardIP) {
|
||||
if len(ips) > 0 && !ips[0].Equal(wildcardIP) {
|
||||
t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0])
|
||||
}
|
||||
|
||||
// Test non-match (base domain)
|
||||
ips = store.GetRecords("autoco.internal.", RecordTypeA)
|
||||
ips, exists = store.GetRecords("autoco.internal.", RecordTypeA)
|
||||
if exists {
|
||||
t.Error("Expected base domain to not exist")
|
||||
}
|
||||
if len(ips) != 0 {
|
||||
t.Errorf("Expected 0 IPs for base domain, got %d", len(ips))
|
||||
}
|
||||
@@ -218,7 +227,10 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test matching domain
|
||||
ips := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA)
|
||||
ips, exists := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA)
|
||||
if !exists {
|
||||
t.Error("Expected complex wildcard match to exist")
|
||||
}
|
||||
if len(ips) != 1 {
|
||||
t.Errorf("Expected 1 IP for complex wildcard match, got %d", len(ips))
|
||||
}
|
||||
@@ -227,13 +239,19 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test non-matching domain (missing prefix)
|
||||
ips = store.GetRecords("host-01.autoco.internal.", RecordTypeA)
|
||||
ips, exists = store.GetRecords("host-01.autoco.internal.", RecordTypeA)
|
||||
if exists {
|
||||
t.Error("Expected domain without prefix to not exist")
|
||||
}
|
||||
if len(ips) != 0 {
|
||||
t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips))
|
||||
}
|
||||
|
||||
// Test non-matching domain (wrong ? position)
|
||||
ips = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA)
|
||||
ips, exists = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA)
|
||||
if exists {
|
||||
t.Error("Expected domain with wrong ? match to not exist")
|
||||
}
|
||||
if len(ips) != 0 {
|
||||
t.Errorf("Expected 0 IPs for domain with wrong ? match, got %d", len(ips))
|
||||
}
|
||||
@@ -250,7 +268,10 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify it exists
|
||||
ips := store.GetRecords("host.autoco.internal.", RecordTypeA)
|
||||
ips, exists := store.GetRecords("host.autoco.internal.", RecordTypeA)
|
||||
if !exists {
|
||||
t.Error("Expected domain to exist before removal")
|
||||
}
|
||||
if len(ips) != 1 {
|
||||
t.Errorf("Expected 1 IP before removal, got %d", len(ips))
|
||||
}
|
||||
@@ -259,7 +280,10 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
|
||||
store.RemoveRecord("*.autoco.internal", nil)
|
||||
|
||||
// Verify it's gone
|
||||
ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
|
||||
ips, exists = store.GetRecords("host.autoco.internal.", RecordTypeA)
|
||||
if exists {
|
||||
t.Error("Expected domain to not exist after removal")
|
||||
}
|
||||
if len(ips) != 0 {
|
||||
t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
|
||||
}
|
||||
@@ -290,19 +314,19 @@ func TestDNSRecordStoreMultipleWildcards(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test domain matching only the prod pattern and the broad pattern
|
||||
ips := store.GetRecords("host.prod.autoco.internal.", RecordTypeA)
|
||||
ips, _ := store.GetRecords("host.prod.autoco.internal.", RecordTypeA)
|
||||
if len(ips) != 2 {
|
||||
t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips))
|
||||
}
|
||||
|
||||
// Test domain matching only the dev pattern and the broad pattern
|
||||
ips = store.GetRecords("service.dev.autoco.internal.", RecordTypeA)
|
||||
ips, _ = store.GetRecords("service.dev.autoco.internal.", RecordTypeA)
|
||||
if len(ips) != 2 {
|
||||
t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips))
|
||||
}
|
||||
|
||||
// Test domain matching only the broad pattern
|
||||
ips = store.GetRecords("host.test.autoco.internal.", RecordTypeA)
|
||||
ips, _ = store.GetRecords("host.test.autoco.internal.", RecordTypeA)
|
||||
if len(ips) != 1 {
|
||||
t.Errorf("Expected 1 IP (broad only), got %d", len(ips))
|
||||
}
|
||||
@@ -319,7 +343,7 @@ func TestDNSRecordStoreIPv6Wildcard(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test wildcard match for IPv6
|
||||
ips := store.GetRecords("host.autoco.internal.", RecordTypeAAAA)
|
||||
ips, _ := store.GetRecords("host.autoco.internal.", RecordTypeAAAA)
|
||||
if len(ips) != 1 {
|
||||
t.Errorf("Expected 1 IPv6 for wildcard match, got %d", len(ips))
|
||||
}
|
||||
@@ -368,7 +392,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, domain := range testCases {
|
||||
ips := store.GetRecords(domain, RecordTypeA)
|
||||
ips, _ := store.GetRecords(domain, RecordTypeA)
|
||||
if len(ips) != 1 {
|
||||
t.Errorf("Expected 1 IP for domain %q, got %d", domain, len(ips))
|
||||
}
|
||||
@@ -392,7 +416,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, domain := range wildcardTestCases {
|
||||
ips := store.GetRecords(domain, RecordTypeA)
|
||||
ips, _ := store.GetRecords(domain, RecordTypeA)
|
||||
if len(ips) != 1 {
|
||||
t.Errorf("Expected 1 IP for wildcard domain %q, got %d", domain, len(ips))
|
||||
}
|
||||
@@ -403,7 +427,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
|
||||
|
||||
// Test removal with different case
|
||||
store.RemoveRecord("MYHOST.AUTOCO.INTERNAL", nil)
|
||||
ips := store.GetRecords("myhost.autoco.internal.", RecordTypeA)
|
||||
ips, _ := store.GetRecords("myhost.autoco.internal.", RecordTypeA)
|
||||
if len(ips) != 0 {
|
||||
t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
|
||||
}
|
||||
@@ -413,3 +437,452 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
|
||||
t.Error("Expected HasRecord to return true for mixed case wildcard match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPTRRecordIPv4(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add PTR record for IPv4
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
domain := "host.example.com."
|
||||
err := store.AddPTRRecord(ip, domain)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add PTR record: %v", err)
|
||||
}
|
||||
|
||||
// Test reverse DNS lookup
|
||||
reverseDomain := "1.1.168.192.in-addr.arpa."
|
||||
result, ok := store.GetPTRRecord(reverseDomain)
|
||||
if !ok {
|
||||
t.Error("Expected PTR record to be found")
|
||||
}
|
||||
if result != domain {
|
||||
t.Errorf("Expected domain %q, got %q", domain, result)
|
||||
}
|
||||
|
||||
// Test HasPTRRecord
|
||||
if !store.HasPTRRecord(reverseDomain) {
|
||||
t.Error("Expected HasPTRRecord to return true")
|
||||
}
|
||||
|
||||
// Test non-existent PTR record
|
||||
_, ok = store.GetPTRRecord("2.1.168.192.in-addr.arpa.")
|
||||
if ok {
|
||||
t.Error("Expected PTR record not to be found for different IP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPTRRecordIPv6(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add PTR record for IPv6
|
||||
ip := net.ParseIP("2001:db8::1")
|
||||
domain := "ipv6host.example.com."
|
||||
err := store.AddPTRRecord(ip, domain)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add PTR record: %v", err)
|
||||
}
|
||||
|
||||
// Test reverse DNS lookup
|
||||
// 2001:db8::1 = 2001:0db8:0000:0000:0000:0000:0000:0001
|
||||
// Reverse: 1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.
|
||||
reverseDomain := "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa."
|
||||
result, ok := store.GetPTRRecord(reverseDomain)
|
||||
if !ok {
|
||||
t.Error("Expected IPv6 PTR record to be found")
|
||||
}
|
||||
if result != domain {
|
||||
t.Errorf("Expected domain %q, got %q", domain, result)
|
||||
}
|
||||
|
||||
// Test HasPTRRecord
|
||||
if !store.HasPTRRecord(reverseDomain) {
|
||||
t.Error("Expected HasPTRRecord to return true for IPv6")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemovePTRRecord(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add PTR record
|
||||
ip := net.ParseIP("10.0.0.1")
|
||||
domain := "test.example.com."
|
||||
err := store.AddPTRRecord(ip, domain)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add PTR record: %v", err)
|
||||
}
|
||||
|
||||
// Verify it exists
|
||||
reverseDomain := "1.0.0.10.in-addr.arpa."
|
||||
_, ok := store.GetPTRRecord(reverseDomain)
|
||||
if !ok {
|
||||
t.Error("Expected PTR record to exist before removal")
|
||||
}
|
||||
|
||||
// Remove PTR record
|
||||
store.RemovePTRRecord(ip)
|
||||
|
||||
// Verify it's gone
|
||||
_, ok = store.GetPTRRecord(reverseDomain)
|
||||
if ok {
|
||||
t.Error("Expected PTR record to be removed")
|
||||
}
|
||||
|
||||
// Test HasPTRRecord after removal
|
||||
if store.HasPTRRecord(reverseDomain) {
|
||||
t.Error("Expected HasPTRRecord to return false after removal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPToReverseDNS(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "IPv4 simple",
|
||||
ip: "192.168.1.1",
|
||||
expected: "1.1.168.192.in-addr.arpa.",
|
||||
},
|
||||
{
|
||||
name: "IPv4 localhost",
|
||||
ip: "127.0.0.1",
|
||||
expected: "1.0.0.127.in-addr.arpa.",
|
||||
},
|
||||
{
|
||||
name: "IPv4 with zeros",
|
||||
ip: "10.0.0.1",
|
||||
expected: "1.0.0.10.in-addr.arpa.",
|
||||
},
|
||||
{
|
||||
name: "IPv6 simple",
|
||||
ip: "2001:db8::1",
|
||||
expected: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
|
||||
},
|
||||
{
|
||||
name: "IPv6 localhost",
|
||||
ip: "::1",
|
||||
expected: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
if ip == nil {
|
||||
t.Fatalf("Failed to parse IP: %s", tt.ip)
|
||||
}
|
||||
result := IPToReverseDNS(ip)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IPToReverseDNS(%s) = %q, want %q", tt.ip, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseDNSToIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reverseDNS string
|
||||
expectedIP string
|
||||
shouldMatch bool
|
||||
}{
|
||||
{
|
||||
name: "IPv4 simple",
|
||||
reverseDNS: "1.1.168.192.in-addr.arpa.",
|
||||
expectedIP: "192.168.1.1",
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "IPv4 localhost",
|
||||
reverseDNS: "1.0.0.127.in-addr.arpa.",
|
||||
expectedIP: "127.0.0.1",
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 simple",
|
||||
reverseDNS: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
|
||||
expectedIP: "2001:db8::1",
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid IPv4 format",
|
||||
reverseDNS: "1.1.168.in-addr.arpa.",
|
||||
expectedIP: "",
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid IPv6 format",
|
||||
reverseDNS: "1.0.0.0.ip6.arpa.",
|
||||
expectedIP: "",
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "Not a reverse DNS domain",
|
||||
reverseDNS: "example.com.",
|
||||
expectedIP: "",
|
||||
shouldMatch: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := reverseDNSToIP(tt.reverseDNS)
|
||||
if tt.shouldMatch {
|
||||
if result == nil {
|
||||
t.Errorf("reverseDNSToIP(%q) returned nil, expected IP", tt.reverseDNS)
|
||||
return
|
||||
}
|
||||
expectedIP := net.ParseIP(tt.expectedIP)
|
||||
if !result.Equal(expectedIP) {
|
||||
t.Errorf("reverseDNSToIP(%q) = %v, want %v", tt.reverseDNS, result, expectedIP)
|
||||
}
|
||||
} else {
|
||||
if result != nil {
|
||||
t.Errorf("reverseDNSToIP(%q) = %v, expected nil", tt.reverseDNS, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPTRRecordCaseInsensitive(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add PTR record with mixed case domain
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
domain := "MyHost.Example.Com"
|
||||
err := store.AddPTRRecord(ip, domain)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add PTR record: %v", err)
|
||||
}
|
||||
|
||||
// Test lookup with different cases in reverse DNS format
|
||||
reverseDomain := "1.1.168.192.in-addr.arpa."
|
||||
result, ok := store.GetPTRRecord(reverseDomain)
|
||||
if !ok {
|
||||
t.Error("Expected PTR record to be found")
|
||||
}
|
||||
// Domain should be normalized to lowercase
|
||||
if result != "myhost.example.com." {
|
||||
t.Errorf("Expected normalized domain %q, got %q", "myhost.example.com.", result)
|
||||
}
|
||||
|
||||
// Test with uppercase reverse DNS
|
||||
reverseDomainUpper := "1.1.168.192.IN-ADDR.ARPA."
|
||||
result, ok = store.GetPTRRecord(reverseDomainUpper)
|
||||
if !ok {
|
||||
t.Error("Expected PTR record to be found with uppercase reverse DNS")
|
||||
}
|
||||
if result != "myhost.example.com." {
|
||||
t.Errorf("Expected normalized domain %q, got %q", "myhost.example.com.", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearPTRRecords(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add some PTR records
|
||||
ip1 := net.ParseIP("192.168.1.1")
|
||||
ip2 := net.ParseIP("192.168.1.2")
|
||||
store.AddPTRRecord(ip1, "host1.example.com.")
|
||||
store.AddPTRRecord(ip2, "host2.example.com.")
|
||||
|
||||
// Add some A records too
|
||||
store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1"))
|
||||
|
||||
// Verify PTR records exist
|
||||
if !store.HasPTRRecord("1.1.168.192.in-addr.arpa.") {
|
||||
t.Error("Expected PTR record to exist before clear")
|
||||
}
|
||||
|
||||
// Clear all records
|
||||
store.Clear()
|
||||
|
||||
// Verify PTR records are gone
|
||||
if store.HasPTRRecord("1.1.168.192.in-addr.arpa.") {
|
||||
t.Error("Expected PTR record to be cleared")
|
||||
}
|
||||
if store.HasPTRRecord("2.1.168.192.in-addr.arpa.") {
|
||||
t.Error("Expected PTR record to be cleared")
|
||||
}
|
||||
|
||||
// Verify A records are also gone
|
||||
if store.HasRecord("test.example.com.", RecordTypeA) {
|
||||
t.Error("Expected A record to be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutomaticPTRRecordOnAdd(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add an A record - should automatically add PTR record
|
||||
domain := "host.example.com."
|
||||
ip := net.ParseIP("192.168.1.100")
|
||||
err := store.AddRecord(domain, ip)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add A record: %v", err)
|
||||
}
|
||||
|
||||
// Verify PTR record was automatically created
|
||||
reverseDomain := "100.1.168.192.in-addr.arpa."
|
||||
result, ok := store.GetPTRRecord(reverseDomain)
|
||||
if !ok {
|
||||
t.Error("Expected PTR record to be automatically created")
|
||||
}
|
||||
if result != domain {
|
||||
t.Errorf("Expected PTR to point to %q, got %q", domain, result)
|
||||
}
|
||||
|
||||
// Add AAAA record - should also automatically add PTR record
|
||||
domain6 := "ipv6host.example.com."
|
||||
ip6 := net.ParseIP("2001:db8::1")
|
||||
err = store.AddRecord(domain6, ip6)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add AAAA record: %v", err)
|
||||
}
|
||||
|
||||
// Verify IPv6 PTR record was automatically created
|
||||
reverseDomain6 := "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa."
|
||||
result6, ok := store.GetPTRRecord(reverseDomain6)
|
||||
if !ok {
|
||||
t.Error("Expected IPv6 PTR record to be automatically created")
|
||||
}
|
||||
if result6 != domain6 {
|
||||
t.Errorf("Expected PTR to point to %q, got %q", domain6, result6)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutomaticPTRRecordOnRemove(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add an A record (with automatic PTR)
|
||||
domain := "host.example.com."
|
||||
ip := net.ParseIP("192.168.1.100")
|
||||
store.AddRecord(domain, ip)
|
||||
|
||||
// Verify PTR exists
|
||||
reverseDomain := "100.1.168.192.in-addr.arpa."
|
||||
if !store.HasPTRRecord(reverseDomain) {
|
||||
t.Error("Expected PTR record to exist after adding A record")
|
||||
}
|
||||
|
||||
// Remove the A record
|
||||
store.RemoveRecord(domain, ip)
|
||||
|
||||
// Verify PTR was automatically removed
|
||||
if store.HasPTRRecord(reverseDomain) {
|
||||
t.Error("Expected PTR record to be automatically removed")
|
||||
}
|
||||
|
||||
// Verify A record is also gone
|
||||
ips, _ := store.GetRecords(domain, RecordTypeA)
|
||||
if len(ips) != 0 {
|
||||
t.Errorf("Expected A record to be removed, got %d records", len(ips))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutomaticPTRRecordOnRemoveAll(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add multiple IPs for the same domain
|
||||
domain := "host.example.com."
|
||||
ip1 := net.ParseIP("192.168.1.100")
|
||||
ip2 := net.ParseIP("192.168.1.101")
|
||||
store.AddRecord(domain, ip1)
|
||||
store.AddRecord(domain, ip2)
|
||||
|
||||
// Verify both PTR records exist
|
||||
reverseDomain1 := "100.1.168.192.in-addr.arpa."
|
||||
reverseDomain2 := "101.1.168.192.in-addr.arpa."
|
||||
if !store.HasPTRRecord(reverseDomain1) {
|
||||
t.Error("Expected first PTR record to exist")
|
||||
}
|
||||
if !store.HasPTRRecord(reverseDomain2) {
|
||||
t.Error("Expected second PTR record to exist")
|
||||
}
|
||||
|
||||
// Remove all records for the domain
|
||||
store.RemoveRecord(domain, nil)
|
||||
|
||||
// Verify both PTR records were removed
|
||||
if store.HasPTRRecord(reverseDomain1) {
|
||||
t.Error("Expected first PTR record to be removed")
|
||||
}
|
||||
if store.HasPTRRecord(reverseDomain2) {
|
||||
t.Error("Expected second PTR record to be removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoPTRForWildcardRecords(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add wildcard record - should NOT create PTR record
|
||||
domain := "*.example.com."
|
||||
ip := net.ParseIP("192.168.1.100")
|
||||
err := store.AddRecord(domain, ip)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add wildcard record: %v", err)
|
||||
}
|
||||
|
||||
// Verify no PTR record was created
|
||||
reverseDomain := "100.1.168.192.in-addr.arpa."
|
||||
_, ok := store.GetPTRRecord(reverseDomain)
|
||||
if ok {
|
||||
t.Error("Expected no PTR record for wildcard domain")
|
||||
}
|
||||
|
||||
// Verify wildcard A record exists
|
||||
if !store.HasRecord("host.example.com.", RecordTypeA) {
|
||||
t.Error("Expected wildcard A record to exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPTRRecordOverwrite(t *testing.T) {
|
||||
store := NewDNSRecordStore()
|
||||
|
||||
// Add first domain with IP
|
||||
domain1 := "host1.example.com."
|
||||
ip := net.ParseIP("192.168.1.100")
|
||||
store.AddRecord(domain1, ip)
|
||||
|
||||
// Verify PTR points to first domain
|
||||
reverseDomain := "100.1.168.192.in-addr.arpa."
|
||||
result, ok := store.GetPTRRecord(reverseDomain)
|
||||
if !ok {
|
||||
t.Fatal("Expected PTR record to exist")
|
||||
}
|
||||
if result != domain1 {
|
||||
t.Errorf("Expected PTR to point to %q, got %q", domain1, result)
|
||||
}
|
||||
|
||||
// Add second domain with same IP - should overwrite PTR
|
||||
domain2 := "host2.example.com."
|
||||
store.AddRecord(domain2, ip)
|
||||
|
||||
// Verify PTR now points to second domain (last one added)
|
||||
result, ok = store.GetPTRRecord(reverseDomain)
|
||||
if !ok {
|
||||
t.Fatal("Expected PTR record to still exist")
|
||||
}
|
||||
if result != domain2 {
|
||||
t.Errorf("Expected PTR to point to %q (overwritten), got %q", domain2, result)
|
||||
}
|
||||
|
||||
// Remove first domain - PTR should remain pointing to second domain
|
||||
store.RemoveRecord(domain1, ip)
|
||||
result, ok = store.GetPTRRecord(reverseDomain)
|
||||
if !ok {
|
||||
t.Error("Expected PTR record to still exist after removing first domain")
|
||||
}
|
||||
if result != domain2 {
|
||||
t.Errorf("Expected PTR to still point to %q, got %q", domain2, result)
|
||||
}
|
||||
|
||||
// Remove second domain - PTR should now be gone
|
||||
store.RemoveRecord(domain2, ip)
|
||||
_, ok = store.GetPTRRecord(reverseDomain)
|
||||
if ok {
|
||||
t.Error("Expected PTR record to be removed after removing second domain")
|
||||
}
|
||||
}
|
||||
|
||||
2
olm.iss
2
olm.iss
@@ -32,7 +32,7 @@ DefaultGroupName={#MyAppName}
|
||||
DisableProgramGroupPage=yes
|
||||
; Uncomment the following line to run in non administrative install mode (install for current user only).
|
||||
;PrivilegesRequired=lowest
|
||||
OutputBaseFilename=mysetup
|
||||
OutputBaseFilename=olm_windows_installer
|
||||
SolidCompression=yes
|
||||
WizardStyle=modern
|
||||
; Add this to ensure PATH changes are applied and the system is prompted for a restart if needed
|
||||
|
||||
14
olm/data.go
14
olm/data.go
@@ -2,6 +2,7 @@ package olm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/holepunch"
|
||||
@@ -220,6 +221,7 @@ func (o *Olm) handleSync(msg websocket.WSMessage) {
|
||||
logger.Info("Sync: Adding new peer for site %d", siteId)
|
||||
|
||||
o.holePunchManager.TriggerHolePunch()
|
||||
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
|
||||
|
||||
// // TODO: do we need to send the message to the cloud to add the peer that way?
|
||||
// if err := o.peerManager.AddPeer(expectedSite); err != nil {
|
||||
@@ -230,9 +232,17 @@ func (o *Olm) handleSync(msg websocket.WSMessage) {
|
||||
|
||||
// add the peer via the server
|
||||
// this is important because newt needs to get triggered as well to add the peer once the hp is complete
|
||||
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
||||
chainId := fmt.Sprintf("sync-%d", expectedSite.SiteId)
|
||||
o.peerSendMu.Lock()
|
||||
if stop, ok := o.stopPeerSends[chainId]; ok {
|
||||
stop()
|
||||
}
|
||||
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
||||
"siteId": expectedSite.SiteId,
|
||||
}, 1*time.Second, 10)
|
||||
"chainId": chainId,
|
||||
}, 2*time.Second, 10)
|
||||
o.stopPeerSends[chainId] = stopFunc
|
||||
o.peerSendMu.Unlock()
|
||||
|
||||
} else {
|
||||
// Existing peer - check if update is needed
|
||||
|
||||
49
olm/olm.go
49
olm/olm.go
@@ -2,6 +2,8 @@ package olm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -65,7 +67,9 @@ type Olm struct {
|
||||
stopRegister func()
|
||||
updateRegister func(newData any)
|
||||
|
||||
stopPeerSend func()
|
||||
stopPeerSends map[string]func()
|
||||
stopPeerInits map[string]func()
|
||||
peerSendMu sync.Mutex
|
||||
|
||||
// WaitGroup to track tunnel lifecycle
|
||||
tunnelWg sync.WaitGroup
|
||||
@@ -116,6 +120,13 @@ func (o *Olm) initTunnelInfo(clientID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateChainId generates a random chain ID for tracking peer sender lifecycles.
|
||||
func generateChainId() string {
|
||||
b := make([]byte, 8)
|
||||
_, _ = rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
|
||||
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
|
||||
|
||||
@@ -170,6 +181,8 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
|
||||
olmCtx: ctx,
|
||||
apiServer: apiServer,
|
||||
olmConfig: config,
|
||||
stopPeerSends: make(map[string]func()),
|
||||
stopPeerInits: make(map[string]func()),
|
||||
}
|
||||
|
||||
newOlm.registerAPICallbacks()
|
||||
@@ -284,6 +297,21 @@ func (o *Olm) registerAPICallbacks() {
|
||||
logger.Info("Processing power mode change request via API: mode=%s", req.Mode)
|
||||
return o.SetPowerMode(req.Mode)
|
||||
},
|
||||
func(req api.JITConnectionRequest) error {
|
||||
logger.Info("Processing JIT connect request via API: site=%s resource=%s", req.Site, req.Resource)
|
||||
|
||||
chainId := generateChainId()
|
||||
o.peerSendMu.Lock()
|
||||
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{
|
||||
"siteId": req.Site,
|
||||
"resourceId": req.Resource,
|
||||
"chainId": chainId,
|
||||
}, 2*time.Second, 10)
|
||||
o.stopPeerInits[chainId] = stopFunc
|
||||
o.peerSendMu.Unlock()
|
||||
|
||||
return nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -378,6 +406,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
||||
|
||||
// Handler for peer handshake - adds exit node to holepunch rotation and notifies server
|
||||
o.websocket.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite)
|
||||
o.websocket.RegisterHandler("olm/wg/peer/chain/cancel", o.handleCancelChain)
|
||||
o.websocket.RegisterHandler("olm/sync", o.handleSync)
|
||||
|
||||
o.websocket.OnConnect(func() error {
|
||||
@@ -420,7 +449,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
||||
"userToken": userToken,
|
||||
"fingerprint": o.fingerprint,
|
||||
"postures": o.postures,
|
||||
}, 1*time.Second, 10)
|
||||
}, 2*time.Second, 10)
|
||||
|
||||
// Invoke onRegistered callback if configured
|
||||
if o.olmConfig.OnRegistered != nil {
|
||||
@@ -517,6 +546,22 @@ func (o *Olm) Close() {
|
||||
o.stopRegister = nil
|
||||
}
|
||||
|
||||
// Stop all pending peer init and send senders before closing websocket
|
||||
o.peerSendMu.Lock()
|
||||
for _, stop := range o.stopPeerInits {
|
||||
if stop != nil {
|
||||
stop()
|
||||
}
|
||||
}
|
||||
o.stopPeerInits = make(map[string]func())
|
||||
for _, stop := range o.stopPeerSends {
|
||||
if stop != nil {
|
||||
stop()
|
||||
}
|
||||
}
|
||||
o.stopPeerSends = make(map[string]func())
|
||||
o.peerSendMu.Unlock()
|
||||
|
||||
// send a disconnect message to the cloud to show disconnected
|
||||
if o.websocket != nil {
|
||||
o.websocket.SendMessage("olm/disconnecting", map[string]any{})
|
||||
|
||||
114
olm/peer.go
114
olm/peer.go
@@ -20,31 +20,38 @@ func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
if o.stopPeerSend != nil {
|
||||
o.stopPeerSend()
|
||||
o.stopPeerSend = nil
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var siteConfig peers.SiteConfig
|
||||
if err := json.Unmarshal(jsonData, &siteConfig); err != nil {
|
||||
var siteConfigMsg struct {
|
||||
peers.SiteConfig
|
||||
ChainId string `json:"chainId"`
|
||||
}
|
||||
if err := json.Unmarshal(jsonData, &siteConfigMsg); err != nil {
|
||||
logger.Error("Error unmarshaling add data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if siteConfigMsg.ChainId != "" {
|
||||
o.peerSendMu.Lock()
|
||||
if stop, ok := o.stopPeerSends[siteConfigMsg.ChainId]; ok {
|
||||
stop()
|
||||
delete(o.stopPeerSends, siteConfigMsg.ChainId)
|
||||
}
|
||||
o.peerSendMu.Unlock()
|
||||
}
|
||||
|
||||
_ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
|
||||
|
||||
if err := o.peerManager.AddPeer(siteConfig); err != nil {
|
||||
if err := o.peerManager.AddPeer(siteConfigMsg.SiteConfig); err != nil {
|
||||
logger.Error("Failed to add peer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Successfully added peer for site %d", siteConfig.SiteId)
|
||||
logger.Info("Successfully added peer for site %d", siteConfigMsg.SiteId)
|
||||
}
|
||||
|
||||
func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
|
||||
@@ -164,12 +171,19 @@ func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
var relayData peers.RelayPeerData
|
||||
var relayData struct {
|
||||
peers.RelayPeerData
|
||||
ChainId string `json:"chainId"`
|
||||
}
|
||||
if err := json.Unmarshal(jsonData, &relayData); err != nil {
|
||||
logger.Error("Error unmarshaling relay data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if monitor := o.peerManager.GetPeerMonitor(); monitor != nil {
|
||||
monitor.CancelRelaySend(relayData.ChainId)
|
||||
}
|
||||
|
||||
primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve primary relay endpoint: %v", err)
|
||||
@@ -197,12 +211,19 @@ func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
var relayData peers.UnRelayPeerData
|
||||
var relayData struct {
|
||||
peers.UnRelayPeerData
|
||||
ChainId string `json:"chainId"`
|
||||
}
|
||||
if err := json.Unmarshal(jsonData, &relayData); err != nil {
|
||||
logger.Error("Error unmarshaling relay data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if monitor := o.peerManager.GetPeerMonitor(); monitor != nil {
|
||||
monitor.CancelRelaySend(relayData.ChainId)
|
||||
}
|
||||
|
||||
primaryRelay, err := util.ResolveDomain(relayData.Endpoint)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
|
||||
@@ -231,6 +252,7 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
|
||||
|
||||
var handshakeData struct {
|
||||
SiteId int `json:"siteId"`
|
||||
ChainId string `json:"chainId"`
|
||||
ExitNode struct {
|
||||
PublicKey string `json:"publicKey"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
@@ -243,6 +265,16 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
// Stop the peer init sender for this chain, if any
|
||||
if handshakeData.ChainId != "" {
|
||||
o.peerSendMu.Lock()
|
||||
if stop, ok := o.stopPeerInits[handshakeData.ChainId]; ok {
|
||||
stop()
|
||||
delete(o.stopPeerInits, handshakeData.ChainId)
|
||||
}
|
||||
o.peerSendMu.Unlock()
|
||||
}
|
||||
|
||||
// Get existing peer from PeerManager
|
||||
_, exists := o.peerManager.GetPeer(handshakeData.SiteId)
|
||||
if exists {
|
||||
@@ -273,10 +305,64 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
|
||||
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
|
||||
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
|
||||
|
||||
// Send handshake acknowledgment back to server with retry
|
||||
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
||||
// Send handshake acknowledgment back to server with retry, keyed by chainId
|
||||
chainId := handshakeData.ChainId
|
||||
if chainId == "" {
|
||||
chainId = generateChainId()
|
||||
}
|
||||
o.peerSendMu.Lock()
|
||||
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
||||
"siteId": handshakeData.SiteId,
|
||||
}, 1*time.Second, 10)
|
||||
"chainId": chainId,
|
||||
}, 2*time.Second, 10)
|
||||
o.stopPeerSends[chainId] = stopFunc
|
||||
o.peerSendMu.Unlock()
|
||||
|
||||
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
|
||||
}
|
||||
|
||||
func (o *Olm) handleCancelChain(msg websocket.WSMessage) {
|
||||
logger.Debug("Received cancel-chain message: %v", msg.Data)
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling cancel-chain data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var cancelData struct {
|
||||
ChainId string `json:"chainId"`
|
||||
}
|
||||
if err := json.Unmarshal(jsonData, &cancelData); err != nil {
|
||||
logger.Error("Error unmarshaling cancel-chain data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if cancelData.ChainId == "" {
|
||||
logger.Warn("Received cancel-chain message with no chainId")
|
||||
return
|
||||
}
|
||||
|
||||
o.peerSendMu.Lock()
|
||||
defer o.peerSendMu.Unlock()
|
||||
|
||||
found := false
|
||||
|
||||
if stop, ok := o.stopPeerInits[cancelData.ChainId]; ok {
|
||||
stop()
|
||||
delete(o.stopPeerInits, cancelData.ChainId)
|
||||
found = true
|
||||
}
|
||||
|
||||
if stop, ok := o.stopPeerSends[cancelData.ChainId]; ok {
|
||||
stop()
|
||||
delete(o.stopPeerSends, cancelData.ChainId)
|
||||
found = true
|
||||
}
|
||||
|
||||
if found {
|
||||
logger.Info("Cancelled chain %s", cancelData.ChainId)
|
||||
} else {
|
||||
logger.Warn("Cancel-chain: no active sender found for chain %s", cancelData.ChainId)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package monitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -35,6 +37,10 @@ type PeerMonitor struct {
|
||||
maxAttempts int
|
||||
wsClient *websocket.Client
|
||||
|
||||
// Relay sender tracking
|
||||
relaySends map[string]func()
|
||||
relaySendMu sync.Mutex
|
||||
|
||||
// Netstack fields
|
||||
middleDev *middleDevice.MiddleDevice
|
||||
localIP string
|
||||
@@ -82,6 +88,12 @@ type PeerMonitor struct {
|
||||
}
|
||||
|
||||
// NewPeerMonitor creates a new peer monitor with the given callback
|
||||
func generateChainId() string {
|
||||
b := make([]byte, 8)
|
||||
_, _ = rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API) *PeerMonitor {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pm := &PeerMonitor{
|
||||
@@ -99,6 +111,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
|
||||
holepunchEndpoints: make(map[int]string),
|
||||
holepunchStatus: make(map[int]bool),
|
||||
relayedPeers: make(map[int]bool),
|
||||
relaySends: make(map[string]func()),
|
||||
holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures
|
||||
holepunchFailures: make(map[int]int),
|
||||
// Rapid initial test settings: complete within ~1.5 seconds
|
||||
@@ -396,20 +409,23 @@ func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status Connectio
|
||||
}
|
||||
}
|
||||
|
||||
// sendRelay sends a relay message to the server
|
||||
// sendRelay sends a relay message to the server with retry, keyed by chainId
|
||||
func (pm *PeerMonitor) sendRelay(siteID int) error {
|
||||
if pm.wsClient == nil {
|
||||
return fmt.Errorf("websocket client is nil")
|
||||
}
|
||||
|
||||
err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{
|
||||
chainId := generateChainId()
|
||||
stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/relay", map[string]interface{}{
|
||||
"siteId": siteID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send registration message: %v", err)
|
||||
return err
|
||||
}
|
||||
logger.Info("Sent relay message")
|
||||
"chainId": chainId,
|
||||
}, 2*time.Second, 10)
|
||||
|
||||
pm.relaySendMu.Lock()
|
||||
pm.relaySends[chainId] = stopFunc
|
||||
pm.relaySendMu.Unlock()
|
||||
|
||||
logger.Info("Sent relay message for site %d (chain %s)", siteID, chainId)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -419,23 +435,52 @@ func (pm *PeerMonitor) RequestRelay(siteID int) error {
|
||||
return pm.sendRelay(siteID)
|
||||
}
|
||||
|
||||
// sendUnRelay sends an unrelay message to the server
|
||||
// sendUnRelay sends an unrelay message to the server with retry, keyed by chainId
|
||||
func (pm *PeerMonitor) sendUnRelay(siteID int) error {
|
||||
if pm.wsClient == nil {
|
||||
return fmt.Errorf("websocket client is nil")
|
||||
}
|
||||
|
||||
err := pm.wsClient.SendMessage("olm/wg/unrelay", map[string]interface{}{
|
||||
chainId := generateChainId()
|
||||
stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/unrelay", map[string]interface{}{
|
||||
"siteId": siteID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send registration message: %v", err)
|
||||
return err
|
||||
}
|
||||
logger.Info("Sent unrelay message")
|
||||
"chainId": chainId,
|
||||
}, 2*time.Second, 10)
|
||||
|
||||
pm.relaySendMu.Lock()
|
||||
pm.relaySends[chainId] = stopFunc
|
||||
pm.relaySendMu.Unlock()
|
||||
|
||||
logger.Info("Sent unrelay message for site %d (chain %s)", siteID, chainId)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CancelRelaySend stops the interval sender for the given chainId, if one exists.
|
||||
// If chainId is empty, all active relay senders are stopped.
|
||||
func (pm *PeerMonitor) CancelRelaySend(chainId string) {
|
||||
pm.relaySendMu.Lock()
|
||||
defer pm.relaySendMu.Unlock()
|
||||
|
||||
if chainId == "" {
|
||||
for id, stop := range pm.relaySends {
|
||||
if stop != nil {
|
||||
stop()
|
||||
}
|
||||
delete(pm.relaySends, id)
|
||||
}
|
||||
logger.Info("Cancelled all relay senders")
|
||||
return
|
||||
}
|
||||
|
||||
if stop, ok := pm.relaySends[chainId]; ok {
|
||||
stop()
|
||||
delete(pm.relaySends, chainId)
|
||||
logger.Info("Cancelled relay sender for chain %s", chainId)
|
||||
} else {
|
||||
logger.Warn("CancelRelaySend: no active sender for chain %s", chainId)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops monitoring all peers
|
||||
func (pm *PeerMonitor) Stop() {
|
||||
// Stop holepunch monitor first (outside of mutex to avoid deadlock)
|
||||
@@ -677,6 +722,16 @@ func (pm *PeerMonitor) Close() {
|
||||
// Stop holepunch monitor first (outside of mutex to avoid deadlock)
|
||||
pm.stopHolepunchMonitor()
|
||||
|
||||
// Stop all pending relay senders
|
||||
pm.relaySendMu.Lock()
|
||||
for chainId, stop := range pm.relaySends {
|
||||
if stop != nil {
|
||||
stop()
|
||||
}
|
||||
delete(pm.relaySends, chainId)
|
||||
}
|
||||
pm.relaySendMu.Unlock()
|
||||
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
|
||||
@@ -388,6 +388,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
|
||||
tokenData := map[string]interface{}{
|
||||
"olmId": c.config.ID,
|
||||
"secret": c.config.Secret,
|
||||
"userToken": c.config.UserToken,
|
||||
"orgId": c.config.OrgID,
|
||||
}
|
||||
jsonData, err := json.Marshal(tokenData)
|
||||
|
||||
Reference in New Issue
Block a user