Compare commits

..

15 Commits
main ... jit

Author SHA1 Message Date
Owen
c2b5ef96a4 Jit of aliases working 2026-03-12 17:26:46 -07:00
Owen
e326da3d3e Merge branch 'dev' into jit 2026-03-12 16:53:16 -07:00
Owen
53def4e2f6 Merge branch 'main' into dev 2026-03-12 16:51:06 -07:00
Owen
22cd02ae15 Alias jit handler 2026-03-11 15:56:51 -07:00
André Gilerson
3f258d3500 Fix crash when peer has nil publicKey in site config
Skip sites with empty/nil publicKey instead of passing them to the
WireGuard UAPI layer, which expects a valid 64-char hex string. A nil
key occurs when a Newt site has never connected. Previously this caused
all sites to fail with "hex string does not fit the slice".
2026-03-07 20:44:25 -08:00
Owen
e2690bcc03 Store site id 2026-03-06 16:19:00 -08:00
Owen
f2d0e6a14c Merge branch 'dev' into jit 2026-03-06 16:08:24 -08:00
Laurence
ae88766d85 test(dns): add dns test cases for nodata 2026-03-06 16:08:01 -08:00
Laurence
9ae49e36d5 refactor(dns): simplify DNSRecordStore from trie to map
Replace trie-based domain lookup with simple map for O(1) lookups.
  Add exists boolean to GetRecords for proper NODATA vs NXDOMAIN responses.
2026-03-06 16:08:01 -08:00
Laurence
5ca4825800 refactor(dns): trie + unified record set for DNSRecordStore
- Replace four maps (aRecords, aaaaRecords, aWildcards, aaaaWildcards) with a label trie for exact lookups and a single wildcards map
- Store one recordSet (A + AAAA) per domain/pattern instead of separate A and AAAA maps
- Exact lookups O(labels); PTR unchanged (map); API and behaviour unchanged
2026-03-06 16:08:01 -08:00
Owen
809dbe77de Make chainId in relay message bckwd compat 2026-03-06 15:27:03 -08:00
Owen
c67c2a60a1 Handle canceling sends for relay 2026-03-06 15:15:31 -08:00
Owen
051c0fdfd8 Working jit with chain ids 2026-03-04 17:51:48 -08:00
Owen
e7507e0837 Add api endpoints to jit 2026-03-04 17:01:17 -08:00
Owen
21b66fbb34 Update iss 2026-02-25 14:57:56 -08:00
12 changed files with 846 additions and 301 deletions

View File

@@ -78,6 +78,13 @@ type MetadataChangeRequest struct {
Postures map[string]any `json:"postures"` Postures map[string]any `json:"postures"`
} }
// JITConnectionRequest defines the structure for a dynamic Just-In-Time connection request.
// Either SiteID or ResourceID must be provided (but not necessarily both).
type JITConnectionRequest struct {
Site string `json:"site,omitempty"`
Resource string `json:"resource,omitempty"`
}
// API represents the HTTP server and its state // API represents the HTTP server and its state
type API struct { type API struct {
addr string addr string
@@ -92,6 +99,7 @@ type API struct {
onExit func() error onExit func() error
onRebind func() error onRebind func() error
onPowerMode func(PowerModeRequest) error onPowerMode func(PowerModeRequest) error
onJITConnect func(JITConnectionRequest) error
statusMu sync.RWMutex statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus peerStatuses map[int]*PeerStatus
@@ -143,6 +151,7 @@ func (s *API) SetHandlers(
onExit func() error, onExit func() error,
onRebind func() error, onRebind func() error,
onPowerMode func(PowerModeRequest) error, onPowerMode func(PowerModeRequest) error,
onJITConnect func(JITConnectionRequest) error,
) { ) {
s.onConnect = onConnect s.onConnect = onConnect
s.onSwitchOrg = onSwitchOrg s.onSwitchOrg = onSwitchOrg
@@ -151,6 +160,7 @@ func (s *API) SetHandlers(
s.onExit = onExit s.onExit = onExit
s.onRebind = onRebind s.onRebind = onRebind
s.onPowerMode = onPowerMode s.onPowerMode = onPowerMode
s.onJITConnect = onJITConnect
} }
// Start starts the HTTP server // Start starts the HTTP server
@@ -169,6 +179,7 @@ func (s *API) Start() error {
mux.HandleFunc("/health", s.handleHealth) mux.HandleFunc("/health", s.handleHealth)
mux.HandleFunc("/rebind", s.handleRebind) mux.HandleFunc("/rebind", s.handleRebind)
mux.HandleFunc("/power-mode", s.handlePowerMode) mux.HandleFunc("/power-mode", s.handlePowerMode)
mux.HandleFunc("/jit-connect", s.handleJITConnect)
s.server = &http.Server{ s.server = &http.Server{
Handler: mux, Handler: mux,
@@ -633,6 +644,54 @@ func (s *API) handleRebind(w http.ResponseWriter, r *http.Request) {
}) })
} }
// handleJITConnect handles the /jit-connect endpoint.
// It initiates a dynamic Just-In-Time connection to a site identified by either
// a site or a resource. Exactly one of the two must be provided.
func (s *API) handleJITConnect(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req JITConnectionRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
return
}
// Validate that exactly one of site or resource is provided
if req.Site == "" && req.Resource == "" {
http.Error(w, "Missing required field: either site or resource must be provided", http.StatusBadRequest)
return
}
if req.Site != "" && req.Resource != "" {
http.Error(w, "Ambiguous request: provide either site or resource, not both", http.StatusBadRequest)
return
}
if req.Site != "" {
logger.Info("Received JIT connection request via API: site=%s", req.Site)
} else {
logger.Info("Received JIT connection request via API: resource=%s", req.Resource)
}
if s.onJITConnect != nil {
if err := s.onJITConnect(req); err != nil {
http.Error(w, fmt.Sprintf("JIT connection failed: %v", err), http.StatusInternalServerError)
return
}
} else {
http.Error(w, "JIT connect handler not configured", http.StatusNotImplemented)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusAccepted)
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "JIT connection request accepted",
})
}
// handlePowerMode handles the /power-mode endpoint // handlePowerMode handles the /power-mode endpoint
// This allows changing the power mode between "normal" and "low" // This allows changing the power mode between "normal" and "low"
func (s *API) handlePowerMode(w http.ResponseWriter, r *http.Request) { func (s *API) handlePowerMode(w http.ResponseWriter, r *http.Request) {

View File

@@ -45,6 +45,11 @@ type DNSProxy struct {
tunnelActivePorts map[uint16]bool tunnelActivePorts map[uint16]bool
tunnelPortsLock sync.Mutex tunnelPortsLock sync.Mutex
// jitHandler is called when a local record is resolved for a site that may not be
// connected yet, giving the caller a chance to initiate a JIT connection.
// It is invoked asynchronously so it never blocks DNS resolution.
jitHandler func(siteId int)
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
wg sync.WaitGroup wg sync.WaitGroup
@@ -384,6 +389,16 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie
response = p.checkLocalRecords(msg, question) response = p.checkLocalRecords(msg, question)
} }
// If a local A/AAAA record was found, notify the JIT handler so that the owning
// site can be connected on-demand if it is not yet active.
if response != nil && p.jitHandler != nil &&
(question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA) {
if siteId, ok := p.recordStore.GetSiteIdForDomain(question.Name); ok && siteId != 0 {
handler := p.jitHandler
go handler(siteId)
}
}
// If no local records, forward to upstream // If no local records, forward to upstream
if response == nil { if response == nil {
logger.Debug("No local record for %s, forwarding upstream to %v", question.Name, p.upstreamDNS) logger.Debug("No local record for %s, forwarding upstream to %v", question.Name, p.upstreamDNS)
@@ -447,19 +462,20 @@ func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns
return nil return nil
} }
ips := p.recordStore.GetRecords(question.Name, recordType) ips, exists := p.recordStore.GetRecords(question.Name, recordType)
if len(ips) == 0 { if !exists {
// Domain not found in local records, forward to upstream
return nil return nil
} }
logger.Debug("Found %d local record(s) for %s", len(ips), question.Name) logger.Debug("Found %d local record(s) for %s", len(ips), question.Name)
// Create response message // Create response message (NODATA if no records, otherwise with answers)
response := new(dns.Msg) response := new(dns.Msg)
response.SetReply(query) response.SetReply(query)
response.Authoritative = true response.Authoritative = true
// Add answer records // Add answer records (loop is a no-op if ips is empty)
for _, ip := range ips { for _, ip := range ips {
var rr dns.RR var rr dns.RR
if question.Qtype == dns.TypeA { if question.Qtype == dns.TypeA {
@@ -717,11 +733,20 @@ func (p *DNSProxy) runPacketSender() {
} }
} }
// SetJITHandler registers a callback that is invoked whenever a local DNS record is
// resolved for an A or AAAA query. The siteId identifies which site owns the record.
// The handler is called in its own goroutine so it must be safe to call concurrently.
// Pass nil to disable JIT notifications.
func (p *DNSProxy) SetJITHandler(handler func(siteId int)) {
p.jitHandler = handler
}
// AddDNSRecord adds a DNS record to the local store // AddDNSRecord adds a DNS record to the local store
// domain should be a domain name (e.g., "example.com" or "example.com.") // domain should be a domain name (e.g., "example.com" or "example.com.")
// ip should be a valid IPv4 or IPv6 address // ip should be a valid IPv4 or IPv6 address
func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP) error { func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP, siteId int) error {
return p.recordStore.AddRecord(domain, ip) logger.Debug("Adding dns record for domain %s with IP %s (siteId=%d)", domain, ip.String(), siteId)
return p.recordStore.AddRecord(domain, ip, siteId)
} }
// RemoveDNSRecord removes a DNS record from the local store // RemoveDNSRecord removes a DNS record from the local store
@@ -730,8 +755,9 @@ func (p *DNSProxy) RemoveDNSRecord(domain string, ip net.IP) {
p.recordStore.RemoveRecord(domain, ip) p.recordStore.RemoveRecord(domain, ip)
} }
// GetDNSRecords returns all IP addresses for a domain and record type // GetDNSRecords returns all IP addresses for a domain and record type.
func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) []net.IP { // The second return value indicates whether the domain exists.
func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) ([]net.IP, bool) {
return p.recordStore.GetRecords(domain, recordType) return p.recordStore.GetRecords(domain, recordType)
} }

178
dns/dns_proxy_test.go Normal file
View File

@@ -0,0 +1,178 @@
package dns
import (
"net"
"testing"
"github.com/miekg/dns"
)
func TestCheckLocalRecordsNODATAForAAAA(t *testing.T) {
proxy := &DNSProxy{
recordStore: NewDNSRecordStore(),
}
// Add an A record for a domain (no AAAA record)
ip := net.ParseIP("10.0.0.1")
err := proxy.recordStore.AddRecord("myservice.internal", ip, 0)
if err != nil {
t.Fatalf("Failed to add A record: %v", err)
}
// Query AAAA for domain with only A record - should return NODATA
query := new(dns.Msg)
query.SetQuestion("myservice.internal.", dns.TypeAAAA)
response := proxy.checkLocalRecords(query, query.Question[0])
if response == nil {
t.Fatal("Expected NODATA response, got nil (would forward to upstream)")
}
if response.Rcode != dns.RcodeSuccess {
t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode)
}
if len(response.Answer) != 0 {
t.Errorf("Expected empty answer section for NODATA, got %d answers", len(response.Answer))
}
if !response.Authoritative {
t.Error("Expected response to be authoritative")
}
// Query A for same domain - should return the record
query = new(dns.Msg)
query.SetQuestion("myservice.internal.", dns.TypeA)
response = proxy.checkLocalRecords(query, query.Question[0])
if response == nil {
t.Fatal("Expected response with A record, got nil")
}
if len(response.Answer) != 1 {
t.Fatalf("Expected 1 answer, got %d", len(response.Answer))
}
aRecord, ok := response.Answer[0].(*dns.A)
if !ok {
t.Fatal("Expected A record in answer")
}
if !aRecord.A.Equal(ip.To4()) {
t.Errorf("Expected IP %v, got %v", ip.To4(), aRecord.A)
}
}
func TestCheckLocalRecordsNODATAForA(t *testing.T) {
proxy := &DNSProxy{
recordStore: NewDNSRecordStore(),
}
// Add an AAAA record for a domain (no A record)
ip := net.ParseIP("2001:db8::1")
err := proxy.recordStore.AddRecord("ipv6only.internal", ip, 0)
if err != nil {
t.Fatalf("Failed to add AAAA record: %v", err)
}
// Query A for domain with only AAAA record - should return NODATA
query := new(dns.Msg)
query.SetQuestion("ipv6only.internal.", dns.TypeA)
response := proxy.checkLocalRecords(query, query.Question[0])
if response == nil {
t.Fatal("Expected NODATA response, got nil")
}
if response.Rcode != dns.RcodeSuccess {
t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode)
}
if len(response.Answer) != 0 {
t.Errorf("Expected empty answer section, got %d answers", len(response.Answer))
}
if !response.Authoritative {
t.Error("Expected response to be authoritative")
}
// Query AAAA for same domain - should return the record
query = new(dns.Msg)
query.SetQuestion("ipv6only.internal.", dns.TypeAAAA)
response = proxy.checkLocalRecords(query, query.Question[0])
if response == nil {
t.Fatal("Expected response with AAAA record, got nil")
}
if len(response.Answer) != 1 {
t.Fatalf("Expected 1 answer, got %d", len(response.Answer))
}
aaaaRecord, ok := response.Answer[0].(*dns.AAAA)
if !ok {
t.Fatal("Expected AAAA record in answer")
}
if !aaaaRecord.AAAA.Equal(ip) {
t.Errorf("Expected IP %v, got %v", ip, aaaaRecord.AAAA)
}
}
func TestCheckLocalRecordsNonExistentDomain(t *testing.T) {
proxy := &DNSProxy{
recordStore: NewDNSRecordStore(),
}
// Add a record so the store isn't empty
err := proxy.recordStore.AddRecord("exists.internal", net.ParseIP("10.0.0.1"), 0)
if err != nil {
t.Fatalf("Failed to add record: %v", err)
}
// Query A for non-existent domain - should return nil (forward to upstream)
query := new(dns.Msg)
query.SetQuestion("unknown.internal.", dns.TypeA)
response := proxy.checkLocalRecords(query, query.Question[0])
if response != nil {
t.Error("Expected nil for non-existent domain, got response")
}
// Query AAAA for non-existent domain - should also return nil
query = new(dns.Msg)
query.SetQuestion("unknown.internal.", dns.TypeAAAA)
response = proxy.checkLocalRecords(query, query.Question[0])
if response != nil {
t.Error("Expected nil for non-existent domain, got response")
}
}
func TestCheckLocalRecordsNODATAWildcard(t *testing.T) {
proxy := &DNSProxy{
recordStore: NewDNSRecordStore(),
}
// Add a wildcard A record (no AAAA)
ip := net.ParseIP("10.0.0.1")
err := proxy.recordStore.AddRecord("*.wildcard.internal", ip, 0)
if err != nil {
t.Fatalf("Failed to add wildcard A record: %v", err)
}
// Query AAAA for wildcard-matched domain - should return NODATA
query := new(dns.Msg)
query.SetQuestion("host.wildcard.internal.", dns.TypeAAAA)
response := proxy.checkLocalRecords(query, query.Question[0])
if response == nil {
t.Fatal("Expected NODATA response for wildcard match, got nil")
}
if response.Rcode != dns.RcodeSuccess {
t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode)
}
if len(response.Answer) != 0 {
t.Errorf("Expected empty answer section, got %d answers", len(response.Answer))
}
// Query A for wildcard-matched domain - should return the record
query = new(dns.Msg)
query.SetQuestion("host.wildcard.internal.", dns.TypeA)
response = proxy.checkLocalRecords(query, query.Question[0])
if response == nil {
t.Fatal("Expected response with A record, got nil")
}
if len(response.Answer) != 1 {
t.Fatalf("Expected 1 answer, got %d", len(response.Answer))
}
}

View File

@@ -18,24 +18,28 @@ const (
RecordTypePTR RecordType = RecordType(dns.TypePTR) RecordTypePTR RecordType = RecordType(dns.TypePTR)
) )
// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries // recordSet holds A and AAAA records for a single domain or wildcard pattern
type recordSet struct {
A []net.IP
AAAA []net.IP
SiteId int
}
// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries.
// Exact domains are stored in a map; wildcard patterns are in a separate map.
type DNSRecordStore struct { type DNSRecordStore struct {
mu sync.RWMutex mu sync.RWMutex
aRecords map[string][]net.IP // domain -> list of IPv4 addresses exact map[string]*recordSet // normalized FQDN -> A/AAAA records
aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses wildcards map[string]*recordSet // wildcard pattern -> A/AAAA records
aWildcards map[string][]net.IP // wildcard pattern -> list of IPv4 addresses ptrRecords map[string]string // IP address string -> domain name
aaaaWildcards map[string][]net.IP // wildcard pattern -> list of IPv6 addresses
ptrRecords map[string]string // IP address string -> domain name
} }
// NewDNSRecordStore creates a new DNS record store // NewDNSRecordStore creates a new DNS record store
func NewDNSRecordStore() *DNSRecordStore { func NewDNSRecordStore() *DNSRecordStore {
return &DNSRecordStore{ return &DNSRecordStore{
aRecords: make(map[string][]net.IP), exact: make(map[string]*recordSet),
aaaaRecords: make(map[string][]net.IP), wildcards: make(map[string]*recordSet),
aWildcards: make(map[string][]net.IP), ptrRecords: make(map[string]string),
aaaaWildcards: make(map[string][]net.IP),
ptrRecords: make(map[string]string),
} }
} }
@@ -43,47 +47,57 @@ func NewDNSRecordStore() *DNSRecordStore {
// domain should be in FQDN format (e.g., "example.com.") // domain should be in FQDN format (e.g., "example.com.")
// domain can contain wildcards: * (0+ chars) and ? (exactly 1 char) // domain can contain wildcards: * (0+ chars) and ? (exactly 1 char)
// ip should be a valid IPv4 or IPv6 address // ip should be a valid IPv4 or IPv6 address
// siteId is the site that owns this alias/domain
// Automatically adds a corresponding PTR record for non-wildcard domains // Automatically adds a corresponding PTR record for non-wildcard domains
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error { func (s *DNSRecordStore) AddRecord(domain string, ip net.IP, siteId int) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
// Ensure domain ends with a dot (FQDN format)
if len(domain) == 0 || domain[len(domain)-1] != '.' { if len(domain) == 0 || domain[len(domain)-1] != '.' {
domain = domain + "." domain = domain + "."
} }
// Normalize domain to lowercase FQDN
domain = strings.ToLower(dns.Fqdn(domain)) domain = strings.ToLower(dns.Fqdn(domain))
// Check if domain contains wildcards
isWildcard := strings.ContainsAny(domain, "*?") isWildcard := strings.ContainsAny(domain, "*?")
if ip.To4() != nil { isV4 := ip.To4() != nil
// IPv4 address if !isV4 && ip.To16() == nil {
if isWildcard {
s.aWildcards[domain] = append(s.aWildcards[domain], ip)
} else {
s.aRecords[domain] = append(s.aRecords[domain], ip)
// Automatically add PTR record for non-wildcard domains
s.ptrRecords[ip.String()] = domain
}
} else if ip.To16() != nil {
// IPv6 address
if isWildcard {
s.aaaaWildcards[domain] = append(s.aaaaWildcards[domain], ip)
} else {
s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip)
// Automatically add PTR record for non-wildcard domains
s.ptrRecords[ip.String()] = domain
}
} else {
return &net.ParseError{Type: "IP address", Text: ip.String()} return &net.ParseError{Type: "IP address", Text: ip.String()}
} }
// Choose the appropriate map based on whether this is a wildcard
m := s.exact
if isWildcard {
m = s.wildcards
}
if m[domain] == nil {
m[domain] = &recordSet{SiteId: siteId}
}
rs := m[domain]
if isV4 {
for _, existing := range rs.A {
if existing.Equal(ip) {
return nil
}
}
rs.A = append(rs.A, ip)
} else {
for _, existing := range rs.AAAA {
if existing.Equal(ip) {
return nil
}
}
rs.AAAA = append(rs.AAAA, ip)
}
// Add PTR record for non-wildcard domains
if !isWildcard {
s.ptrRecords[ip.String()] = domain
}
return nil return nil
} }
// AddPTRRecord adds a PTR record mapping an IP address to a domain name // AddPTRRecord adds a PTR record mapping an IP address to a domain name
// ip should be a valid IPv4 or IPv6 address // ip should be a valid IPv4 or IPv6 address
// domain should be in FQDN format (e.g., "example.com.") // domain should be in FQDN format (e.g., "example.com.")
@@ -112,89 +126,62 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
// Ensure domain ends with a dot (FQDN format)
if len(domain) == 0 || domain[len(domain)-1] != '.' { if len(domain) == 0 || domain[len(domain)-1] != '.' {
domain = domain + "." domain = domain + "."
} }
// Normalize domain to lowercase FQDN
domain = strings.ToLower(dns.Fqdn(domain)) domain = strings.ToLower(dns.Fqdn(domain))
// Check if domain contains wildcards
isWildcard := strings.ContainsAny(domain, "*?") isWildcard := strings.ContainsAny(domain, "*?")
if ip == nil { // Choose the appropriate map
// Remove all records for this domain m := s.exact
if isWildcard { if isWildcard {
delete(s.aWildcards, domain) m = s.wildcards
delete(s.aaaaWildcards, domain) }
} else {
// For non-wildcard domains, remove PTR records for all IPs rs := m[domain]
if ips, ok := s.aRecords[domain]; ok { if rs == nil {
for _, ipAddr := range ips {
// Only remove PTR if it points to this domain
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ipAddr.String())
}
}
}
if ips, ok := s.aaaaRecords[domain]; ok {
for _, ipAddr := range ips {
// Only remove PTR if it points to this domain
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ipAddr.String())
}
}
}
delete(s.aRecords, domain)
delete(s.aaaaRecords, domain)
}
return return
} }
if ip == nil {
// Remove all records for this domain
if !isWildcard {
for _, ipAddr := range rs.A {
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ipAddr.String())
}
}
for _, ipAddr := range rs.AAAA {
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ipAddr.String())
}
}
}
delete(m, domain)
return
}
// Remove specific IP
if ip.To4() != nil { if ip.To4() != nil {
// Remove specific IPv4 address rs.A = removeIP(rs.A, ip)
if isWildcard { if !isWildcard {
if ips, ok := s.aWildcards[domain]; ok { if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
s.aWildcards[domain] = removeIP(ips, ip) delete(s.ptrRecords, ip.String())
if len(s.aWildcards[domain]) == 0 {
delete(s.aWildcards, domain)
}
}
} else {
if ips, ok := s.aRecords[domain]; ok {
s.aRecords[domain] = removeIP(ips, ip)
if len(s.aRecords[domain]) == 0 {
delete(s.aRecords, domain)
}
// Automatically remove PTR record if it points to this domain
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ip.String())
}
} }
} }
} else if ip.To16() != nil { } else {
// Remove specific IPv6 address rs.AAAA = removeIP(rs.AAAA, ip)
if isWildcard { if !isWildcard {
if ips, ok := s.aaaaWildcards[domain]; ok { if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
s.aaaaWildcards[domain] = removeIP(ips, ip) delete(s.ptrRecords, ip.String())
if len(s.aaaaWildcards[domain]) == 0 {
delete(s.aaaaWildcards, domain)
}
}
} else {
if ips, ok := s.aaaaRecords[domain]; ok {
s.aaaaRecords[domain] = removeIP(ips, ip)
if len(s.aaaaRecords[domain]) == 0 {
delete(s.aaaaRecords, domain)
}
// Automatically remove PTR record if it points to this domain
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ip.String())
}
} }
} }
} }
// Clean up empty record sets
if len(rs.A) == 0 && len(rs.AAAA) == 0 {
delete(m, domain)
}
} }
// RemovePTRRecord removes a PTR record for an IP address // RemovePTRRecord removes a PTR record for an IP address
@@ -205,61 +192,80 @@ func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) {
delete(s.ptrRecords, ip.String()) delete(s.ptrRecords, ip.String())
} }
// GetRecords returns all IP addresses for a domain and record type // GetSiteIdForDomain returns the siteId associated with the given domain.
// First checks for exact matches, then checks wildcard patterns // It checks exact matches first, then wildcard patterns.
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP { // The second return value is false if the domain is not found in local records.
func (s *DNSRecordStore) GetSiteIdForDomain(domain string) (int, bool) {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
// Normalize domain to lowercase FQDN
domain = strings.ToLower(dns.Fqdn(domain)) domain = strings.ToLower(dns.Fqdn(domain))
var records []net.IP // Check exact match first
switch recordType { if rs, exists := s.exact[domain]; exists {
case RecordTypeA: return rs.SiteId, true
// Check exact match first }
if ips, ok := s.aRecords[domain]; ok {
// Return a copy to prevent external modifications
records = make([]net.IP, len(ips))
copy(records, ips)
return records
}
// Check wildcard patterns
for pattern, ips := range s.aWildcards {
if matchWildcard(pattern, domain) {
records = append(records, ips...)
}
}
if len(records) > 0 {
// Return a copy
result := make([]net.IP, len(records))
copy(result, records)
return result
}
case RecordTypeAAAA: // Check wildcard matches
// Check exact match first for pattern, rs := range s.wildcards {
if ips, ok := s.aaaaRecords[domain]; ok { if matchWildcard(pattern, domain) {
// Return a copy to prevent external modifications return rs.SiteId, true
records = make([]net.IP, len(ips))
copy(records, ips)
return records
}
// Check wildcard patterns
for pattern, ips := range s.aaaaWildcards {
if matchWildcard(pattern, domain) {
records = append(records, ips...)
}
}
if len(records) > 0 {
// Return a copy
result := make([]net.IP, len(records))
copy(result, records)
return result
} }
} }
return records return 0, false
}
// GetRecords returns all IP addresses for a domain and record type.
// The second return value indicates whether the domain exists at all
// (true = domain exists, use NODATA if no records; false = NXDOMAIN).
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) ([]net.IP, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
domain = strings.ToLower(dns.Fqdn(domain))
// Check exact match first
if rs, exists := s.exact[domain]; exists {
var ips []net.IP
if recordType == RecordTypeA {
ips = rs.A
} else {
ips = rs.AAAA
}
if len(ips) > 0 {
out := make([]net.IP, len(ips))
copy(out, ips)
return out, true
}
// Domain exists but no records of this type
return nil, true
}
// Check wildcard matches
var records []net.IP
matched := false
for pattern, rs := range s.wildcards {
if !matchWildcard(pattern, domain) {
continue
}
matched = true
if recordType == RecordTypeA {
records = append(records, rs.A...)
} else {
records = append(records, rs.AAAA...)
}
}
if !matched {
return nil, false
}
if len(records) == 0 {
return nil, true
}
out := make([]net.IP, len(records))
copy(out, records)
return out, true
} }
// GetPTRRecord returns the domain name for a PTR record query // GetPTRRecord returns the domain name for a PTR record query
@@ -288,34 +294,30 @@ func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
// Normalize domain to lowercase FQDN
domain = strings.ToLower(dns.Fqdn(domain)) domain = strings.ToLower(dns.Fqdn(domain))
switch recordType { // Check exact match
case RecordTypeA: if rs, exists := s.exact[domain]; exists {
// Check exact match if recordType == RecordTypeA && len(rs.A) > 0 {
if _, ok := s.aRecords[domain]; ok {
return true return true
} }
// Check wildcard patterns if recordType == RecordTypeAAAA && len(rs.AAAA) > 0 {
for pattern := range s.aWildcards {
if matchWildcard(pattern, domain) {
return true
}
}
case RecordTypeAAAA:
// Check exact match
if _, ok := s.aaaaRecords[domain]; ok {
return true return true
} }
// Check wildcard patterns
for pattern := range s.aaaaWildcards {
if matchWildcard(pattern, domain) {
return true
}
}
} }
// Check wildcard matches
for pattern, rs := range s.wildcards {
if !matchWildcard(pattern, domain) {
continue
}
if recordType == RecordTypeA && len(rs.A) > 0 {
return true
}
if recordType == RecordTypeAAAA && len(rs.AAAA) > 0 {
return true
}
}
return false return false
} }
@@ -339,10 +341,8 @@ func (s *DNSRecordStore) Clear() {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
s.aRecords = make(map[string][]net.IP) s.exact = make(map[string]*recordSet)
s.aaaaRecords = make(map[string][]net.IP) s.wildcards = make(map[string]*recordSet)
s.aWildcards = make(map[string][]net.IP)
s.aaaaWildcards = make(map[string][]net.IP)
s.ptrRecords = make(map[string]string) s.ptrRecords = make(map[string]string)
} }
@@ -494,4 +494,4 @@ func IPToReverseDNS(ip net.IP) string {
} }
return "" return ""
} }

View File

@@ -170,38 +170,47 @@ func TestDNSRecordStoreWildcard(t *testing.T) {
// Add wildcard records // Add wildcard records
wildcardIP := net.ParseIP("10.0.0.1") wildcardIP := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", wildcardIP) err := store.AddRecord("*.autoco.internal", wildcardIP, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err) t.Fatalf("Failed to add wildcard record: %v", err)
} }
// Add exact record // Add exact record
exactIP := net.ParseIP("10.0.0.2") exactIP := net.ParseIP("10.0.0.2")
err = store.AddRecord("exact.autoco.internal", exactIP) err = store.AddRecord("exact.autoco.internal", exactIP, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add exact record: %v", err) t.Fatalf("Failed to add exact record: %v", err)
} }
// Test exact match takes precedence // Test exact match takes precedence
ips := store.GetRecords("exact.autoco.internal.", RecordTypeA) ips, exists := store.GetRecords("exact.autoco.internal.", RecordTypeA)
if !exists {
t.Error("Expected domain to exist")
}
if len(ips) != 1 { if len(ips) != 1 {
t.Errorf("Expected 1 IP for exact match, got %d", len(ips)) t.Errorf("Expected 1 IP for exact match, got %d", len(ips))
} }
if !ips[0].Equal(exactIP) { if len(ips) > 0 && !ips[0].Equal(exactIP) {
t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0]) t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0])
} }
// Test wildcard match // Test wildcard match
ips = store.GetRecords("host.autoco.internal.", RecordTypeA) ips, exists = store.GetRecords("host.autoco.internal.", RecordTypeA)
if !exists {
t.Error("Expected wildcard match to exist")
}
if len(ips) != 1 { if len(ips) != 1 {
t.Errorf("Expected 1 IP for wildcard match, got %d", len(ips)) t.Errorf("Expected 1 IP for wildcard match, got %d", len(ips))
} }
if !ips[0].Equal(wildcardIP) { if len(ips) > 0 && !ips[0].Equal(wildcardIP) {
t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0]) t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0])
} }
// Test non-match (base domain) // Test non-match (base domain)
ips = store.GetRecords("autoco.internal.", RecordTypeA) ips, exists = store.GetRecords("autoco.internal.", RecordTypeA)
if exists {
t.Error("Expected base domain to not exist")
}
if len(ips) != 0 { if len(ips) != 0 {
t.Errorf("Expected 0 IPs for base domain, got %d", len(ips)) t.Errorf("Expected 0 IPs for base domain, got %d", len(ips))
} }
@@ -212,13 +221,16 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) {
// Add complex wildcard pattern // Add complex wildcard pattern
ip1 := net.ParseIP("10.0.0.1") ip1 := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.host-0?.autoco.internal", ip1) err := store.AddRecord("*.host-0?.autoco.internal", ip1, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err) t.Fatalf("Failed to add wildcard record: %v", err)
} }
// Test matching domain // Test matching domain
ips := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA) ips, exists := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA)
if !exists {
t.Error("Expected complex wildcard match to exist")
}
if len(ips) != 1 { if len(ips) != 1 {
t.Errorf("Expected 1 IP for complex wildcard match, got %d", len(ips)) t.Errorf("Expected 1 IP for complex wildcard match, got %d", len(ips))
} }
@@ -227,13 +239,19 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) {
} }
// Test non-matching domain (missing prefix) // Test non-matching domain (missing prefix)
ips = store.GetRecords("host-01.autoco.internal.", RecordTypeA) ips, exists = store.GetRecords("host-01.autoco.internal.", RecordTypeA)
if exists {
t.Error("Expected domain without prefix to not exist")
}
if len(ips) != 0 { if len(ips) != 0 {
t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips)) t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips))
} }
// Test non-matching domain (wrong ? position) // Test non-matching domain (wrong ? position)
ips = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA) ips, exists = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA)
if exists {
t.Error("Expected domain with wrong ? match to not exist")
}
if len(ips) != 0 { if len(ips) != 0 {
t.Errorf("Expected 0 IPs for domain with wrong ? match, got %d", len(ips)) t.Errorf("Expected 0 IPs for domain with wrong ? match, got %d", len(ips))
} }
@@ -244,13 +262,16 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
// Add wildcard record // Add wildcard record
ip := net.ParseIP("10.0.0.1") ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", ip) err := store.AddRecord("*.autoco.internal", ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err) t.Fatalf("Failed to add wildcard record: %v", err)
} }
// Verify it exists // Verify it exists
ips := store.GetRecords("host.autoco.internal.", RecordTypeA) ips, exists := store.GetRecords("host.autoco.internal.", RecordTypeA)
if !exists {
t.Error("Expected domain to exist before removal")
}
if len(ips) != 1 { if len(ips) != 1 {
t.Errorf("Expected 1 IP before removal, got %d", len(ips)) t.Errorf("Expected 1 IP before removal, got %d", len(ips))
} }
@@ -259,7 +280,10 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
store.RemoveRecord("*.autoco.internal", nil) store.RemoveRecord("*.autoco.internal", nil)
// Verify it's gone // Verify it's gone
ips = store.GetRecords("host.autoco.internal.", RecordTypeA) ips, exists = store.GetRecords("host.autoco.internal.", RecordTypeA)
if exists {
t.Error("Expected domain to not exist after removal")
}
if len(ips) != 0 { if len(ips) != 0 {
t.Errorf("Expected 0 IPs after removal, got %d", len(ips)) t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
} }
@@ -273,36 +297,36 @@ func TestDNSRecordStoreMultipleWildcards(t *testing.T) {
ip2 := net.ParseIP("10.0.0.2") ip2 := net.ParseIP("10.0.0.2")
ip3 := net.ParseIP("10.0.0.3") ip3 := net.ParseIP("10.0.0.3")
err := store.AddRecord("*.prod.autoco.internal", ip1) err := store.AddRecord("*.prod.autoco.internal", ip1, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add first wildcard: %v", err) t.Fatalf("Failed to add first wildcard: %v", err)
} }
err = store.AddRecord("*.dev.autoco.internal", ip2) err = store.AddRecord("*.dev.autoco.internal", ip2, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add second wildcard: %v", err) t.Fatalf("Failed to add second wildcard: %v", err)
} }
// Add a broader wildcard that matches both // Add a broader wildcard that matches both
err = store.AddRecord("*.autoco.internal", ip3) err = store.AddRecord("*.autoco.internal", ip3, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add third wildcard: %v", err) t.Fatalf("Failed to add third wildcard: %v", err)
} }
// Test domain matching only the prod pattern and the broad pattern // Test domain matching only the prod pattern and the broad pattern
ips := store.GetRecords("host.prod.autoco.internal.", RecordTypeA) ips, _ := store.GetRecords("host.prod.autoco.internal.", RecordTypeA)
if len(ips) != 2 { if len(ips) != 2 {
t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips)) t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips))
} }
// Test domain matching only the dev pattern and the broad pattern // Test domain matching only the dev pattern and the broad pattern
ips = store.GetRecords("service.dev.autoco.internal.", RecordTypeA) ips, _ = store.GetRecords("service.dev.autoco.internal.", RecordTypeA)
if len(ips) != 2 { if len(ips) != 2 {
t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips)) t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips))
} }
// Test domain matching only the broad pattern // Test domain matching only the broad pattern
ips = store.GetRecords("host.test.autoco.internal.", RecordTypeA) ips, _ = store.GetRecords("host.test.autoco.internal.", RecordTypeA)
if len(ips) != 1 { if len(ips) != 1 {
t.Errorf("Expected 1 IP (broad only), got %d", len(ips)) t.Errorf("Expected 1 IP (broad only), got %d", len(ips))
} }
@@ -313,13 +337,13 @@ func TestDNSRecordStoreIPv6Wildcard(t *testing.T) {
// Add IPv6 wildcard record // Add IPv6 wildcard record
ip := net.ParseIP("2001:db8::1") ip := net.ParseIP("2001:db8::1")
err := store.AddRecord("*.autoco.internal", ip) err := store.AddRecord("*.autoco.internal", ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add IPv6 wildcard record: %v", err) t.Fatalf("Failed to add IPv6 wildcard record: %v", err)
} }
// Test wildcard match for IPv6 // Test wildcard match for IPv6
ips := store.GetRecords("host.autoco.internal.", RecordTypeAAAA) ips, _ := store.GetRecords("host.autoco.internal.", RecordTypeAAAA)
if len(ips) != 1 { if len(ips) != 1 {
t.Errorf("Expected 1 IPv6 for wildcard match, got %d", len(ips)) t.Errorf("Expected 1 IPv6 for wildcard match, got %d", len(ips))
} }
@@ -333,7 +357,7 @@ func TestHasRecordWildcard(t *testing.T) {
// Add wildcard record // Add wildcard record
ip := net.ParseIP("10.0.0.1") ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", ip) err := store.AddRecord("*.autoco.internal", ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err) t.Fatalf("Failed to add wildcard record: %v", err)
} }
@@ -354,7 +378,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
// Add record with mixed case // Add record with mixed case
ip := net.ParseIP("10.0.0.1") ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("MyHost.AutoCo.Internal", ip) err := store.AddRecord("MyHost.AutoCo.Internal", ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add mixed case record: %v", err) t.Fatalf("Failed to add mixed case record: %v", err)
} }
@@ -368,7 +392,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
} }
for _, domain := range testCases { for _, domain := range testCases {
ips := store.GetRecords(domain, RecordTypeA) ips, _ := store.GetRecords(domain, RecordTypeA)
if len(ips) != 1 { if len(ips) != 1 {
t.Errorf("Expected 1 IP for domain %q, got %d", domain, len(ips)) t.Errorf("Expected 1 IP for domain %q, got %d", domain, len(ips))
} }
@@ -379,7 +403,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
// Test wildcard with mixed case // Test wildcard with mixed case
wildcardIP := net.ParseIP("10.0.0.2") wildcardIP := net.ParseIP("10.0.0.2")
err = store.AddRecord("*.Example.Com", wildcardIP) err = store.AddRecord("*.Example.Com", wildcardIP, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add mixed case wildcard: %v", err) t.Fatalf("Failed to add mixed case wildcard: %v", err)
} }
@@ -392,7 +416,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
} }
for _, domain := range wildcardTestCases { for _, domain := range wildcardTestCases {
ips := store.GetRecords(domain, RecordTypeA) ips, _ := store.GetRecords(domain, RecordTypeA)
if len(ips) != 1 { if len(ips) != 1 {
t.Errorf("Expected 1 IP for wildcard domain %q, got %d", domain, len(ips)) t.Errorf("Expected 1 IP for wildcard domain %q, got %d", domain, len(ips))
} }
@@ -403,7 +427,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
// Test removal with different case // Test removal with different case
store.RemoveRecord("MYHOST.AUTOCO.INTERNAL", nil) store.RemoveRecord("MYHOST.AUTOCO.INTERNAL", nil)
ips := store.GetRecords("myhost.autoco.internal.", RecordTypeA) ips, _ := store.GetRecords("myhost.autoco.internal.", RecordTypeA)
if len(ips) != 0 { if len(ips) != 0 {
t.Errorf("Expected 0 IPs after removal, got %d", len(ips)) t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
} }
@@ -665,7 +689,7 @@ func TestClearPTRRecords(t *testing.T) {
store.AddPTRRecord(ip2, "host2.example.com.") store.AddPTRRecord(ip2, "host2.example.com.")
// Add some A records too // Add some A records too
store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1")) store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1"), 0)
// Verify PTR records exist // Verify PTR records exist
if !store.HasPTRRecord("1.1.168.192.in-addr.arpa.") { if !store.HasPTRRecord("1.1.168.192.in-addr.arpa.") {
@@ -695,7 +719,7 @@ func TestAutomaticPTRRecordOnAdd(t *testing.T) {
// Add an A record - should automatically add PTR record // Add an A record - should automatically add PTR record
domain := "host.example.com." domain := "host.example.com."
ip := net.ParseIP("192.168.1.100") ip := net.ParseIP("192.168.1.100")
err := store.AddRecord(domain, ip) err := store.AddRecord(domain, ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add A record: %v", err) t.Fatalf("Failed to add A record: %v", err)
} }
@@ -713,7 +737,7 @@ func TestAutomaticPTRRecordOnAdd(t *testing.T) {
// Add AAAA record - should also automatically add PTR record // Add AAAA record - should also automatically add PTR record
domain6 := "ipv6host.example.com." domain6 := "ipv6host.example.com."
ip6 := net.ParseIP("2001:db8::1") ip6 := net.ParseIP("2001:db8::1")
err = store.AddRecord(domain6, ip6) err = store.AddRecord(domain6, ip6, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add AAAA record: %v", err) t.Fatalf("Failed to add AAAA record: %v", err)
} }
@@ -735,7 +759,7 @@ func TestAutomaticPTRRecordOnRemove(t *testing.T) {
// Add an A record (with automatic PTR) // Add an A record (with automatic PTR)
domain := "host.example.com." domain := "host.example.com."
ip := net.ParseIP("192.168.1.100") ip := net.ParseIP("192.168.1.100")
store.AddRecord(domain, ip) store.AddRecord(domain, ip, 0)
// Verify PTR exists // Verify PTR exists
reverseDomain := "100.1.168.192.in-addr.arpa." reverseDomain := "100.1.168.192.in-addr.arpa."
@@ -752,7 +776,7 @@ func TestAutomaticPTRRecordOnRemove(t *testing.T) {
} }
// Verify A record is also gone // Verify A record is also gone
ips := store.GetRecords(domain, RecordTypeA) ips, _ := store.GetRecords(domain, RecordTypeA)
if len(ips) != 0 { if len(ips) != 0 {
t.Errorf("Expected A record to be removed, got %d records", len(ips)) t.Errorf("Expected A record to be removed, got %d records", len(ips))
} }
@@ -765,8 +789,8 @@ func TestAutomaticPTRRecordOnRemoveAll(t *testing.T) {
domain := "host.example.com." domain := "host.example.com."
ip1 := net.ParseIP("192.168.1.100") ip1 := net.ParseIP("192.168.1.100")
ip2 := net.ParseIP("192.168.1.101") ip2 := net.ParseIP("192.168.1.101")
store.AddRecord(domain, ip1) store.AddRecord(domain, ip1, 0)
store.AddRecord(domain, ip2) store.AddRecord(domain, ip2, 0)
// Verify both PTR records exist // Verify both PTR records exist
reverseDomain1 := "100.1.168.192.in-addr.arpa." reverseDomain1 := "100.1.168.192.in-addr.arpa."
@@ -796,7 +820,7 @@ func TestNoPTRForWildcardRecords(t *testing.T) {
// Add wildcard record - should NOT create PTR record // Add wildcard record - should NOT create PTR record
domain := "*.example.com." domain := "*.example.com."
ip := net.ParseIP("192.168.1.100") ip := net.ParseIP("192.168.1.100")
err := store.AddRecord(domain, ip) err := store.AddRecord(domain, ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err) t.Fatalf("Failed to add wildcard record: %v", err)
} }
@@ -820,7 +844,7 @@ func TestPTRRecordOverwrite(t *testing.T) {
// Add first domain with IP // Add first domain with IP
domain1 := "host1.example.com." domain1 := "host1.example.com."
ip := net.ParseIP("192.168.1.100") ip := net.ParseIP("192.168.1.100")
store.AddRecord(domain1, ip) store.AddRecord(domain1, ip, 0)
// Verify PTR points to first domain // Verify PTR points to first domain
reverseDomain := "100.1.168.192.in-addr.arpa." reverseDomain := "100.1.168.192.in-addr.arpa."
@@ -834,7 +858,7 @@ func TestPTRRecordOverwrite(t *testing.T) {
// Add second domain with same IP - should overwrite PTR // Add second domain with same IP - should overwrite PTR
domain2 := "host2.example.com." domain2 := "host2.example.com."
store.AddRecord(domain2, ip) store.AddRecord(domain2, ip, 0)
// Verify PTR now points to second domain (last one added) // Verify PTR now points to second domain (last one added)
result, ok = store.GetPTRRecord(reverseDomain) result, ok = store.GetPTRRecord(reverseDomain)

14
olm.iss
View File

@@ -32,7 +32,7 @@ DefaultGroupName={#MyAppName}
DisableProgramGroupPage=yes DisableProgramGroupPage=yes
; Uncomment the following line to run in non administrative install mode (install for current user only). ; Uncomment the following line to run in non administrative install mode (install for current user only).
;PrivilegesRequired=lowest ;PrivilegesRequired=lowest
OutputBaseFilename=mysetup OutputBaseFilename=olm_windows_installer
SolidCompression=yes SolidCompression=yes
WizardStyle=modern WizardStyle=modern
; Add this to ensure PATH changes are applied and the system is prompted for a restart if needed ; Add this to ensure PATH changes are applied and the system is prompted for a restart if needed
@@ -78,7 +78,7 @@ begin
Result := True; Result := True;
exit; exit;
end; end;
// Perform a case-insensitive check to see if the path is already present. // Perform a case-insensitive check to see if the path is already present.
// We add semicolons to prevent partial matches (e.g., matching C:\App in C:\App2). // We add semicolons to prevent partial matches (e.g., matching C:\App in C:\App2).
if Pos(';' + UpperCase(Path) + ';', ';' + UpperCase(OrigPath) + ';') > 0 then if Pos(';' + UpperCase(Path) + ';', ';' + UpperCase(OrigPath) + ';') > 0 then
@@ -109,7 +109,7 @@ begin
PathList.Delimiter := ';'; PathList.Delimiter := ';';
PathList.StrictDelimiter := True; PathList.StrictDelimiter := True;
PathList.DelimitedText := OrigPath; PathList.DelimitedText := OrigPath;
// Find and remove the matching entry (case-insensitive) // Find and remove the matching entry (case-insensitive)
for I := PathList.Count - 1 downto 0 do for I := PathList.Count - 1 downto 0 do
begin begin
@@ -119,10 +119,10 @@ begin
PathList.Delete(I); PathList.Delete(I);
end; end;
end; end;
// Reconstruct the PATH // Reconstruct the PATH
NewPath := PathList.DelimitedText; NewPath := PathList.DelimitedText;
// Write the new PATH back to the registry // Write the new PATH back to the registry
if RegWriteExpandStringValue(HKEY_LOCAL_MACHINE, if RegWriteExpandStringValue(HKEY_LOCAL_MACHINE,
'SYSTEM\CurrentControlSet\Control\Session Manager\Environment', 'SYSTEM\CurrentControlSet\Control\Session Manager\Environment',
@@ -145,8 +145,8 @@ begin
// Get the application installation path // Get the application installation path
AppPath := ExpandConstant('{app}'); AppPath := ExpandConstant('{app}');
Log('Removing PATH entry for: ' + AppPath); Log('Removing PATH entry for: ' + AppPath);
// Remove only our path entry from the system PATH // Remove only our path entry from the system PATH
RemovePathEntry(AppPath); RemovePathEntry(AppPath);
end; end;
end; end;

View File

@@ -7,6 +7,7 @@ import (
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"github.com/fosrl/newt/network" "github.com/fosrl/newt/network"
@@ -173,16 +174,20 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
for i := range wgData.Sites { for i := range wgData.Sites {
site := wgData.Sites[i] site := wgData.Sites[i]
var siteEndpoint string
// here we are going to take the relay endpoint if it exists which means we requested a relay for this peer if site.PublicKey != "" {
if site.RelayEndpoint != "" { var siteEndpoint string
siteEndpoint = site.RelayEndpoint // here we are going to take the relay endpoint if it exists which means we requested a relay for this peer
} else { if site.RelayEndpoint != "" {
siteEndpoint = site.Endpoint siteEndpoint = site.RelayEndpoint
} else {
siteEndpoint = site.Endpoint
}
o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false)
} }
o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false) // we still call this to add the aliases for jit lookup but we just do that then pass inside. need to skip the above so we dont add to the api
if err := o.peerManager.AddPeer(site); err != nil { if err := o.peerManager.AddPeer(site); err != nil {
logger.Error("Failed to add peer: %v", err) logger.Error("Failed to add peer: %v", err)
return return
@@ -197,6 +202,36 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
logger.Error("Failed to start DNS proxy: %v", err) logger.Error("Failed to start DNS proxy: %v", err)
} }
// Register JIT handler: when the DNS proxy resolves a local record, check whether
// the owning site is already connected and, if not, initiate a JIT connection.
o.dnsProxy.SetJITHandler(func(siteId int) {
if o.peerManager == nil || o.websocket == nil {
return
}
// Site already has an active peer connection - nothing to do.
if _, exists := o.peerManager.GetPeer(siteId); exists {
return
}
o.peerSendMu.Lock()
defer o.peerSendMu.Unlock()
// A JIT request for this site is already in-flight - avoid duplicate sends.
if _, pending := o.jitPendingSites[siteId]; pending {
return
}
chainId := generateChainId()
logger.Info("DNS-triggered JIT connect for site %d (chainId=%s)", siteId, chainId)
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{
"siteId": siteId,
"chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerInits[chainId] = stopFunc
o.jitPendingSites[siteId] = chainId
})
if o.tunnelConfig.OverrideDNS { if o.tunnelConfig.OverrideDNS {
// Set up DNS override to use our DNS proxy // Set up DNS override to use our DNS proxy
if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil { if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil {
@@ -274,12 +309,12 @@ func (o *Olm) handleTerminate(msg websocket.WSMessage) {
logger.Error("Error unmarshaling terminate error data: %v", err) logger.Error("Error unmarshaling terminate error data: %v", err)
} else { } else {
logger.Info("Terminate reason (code: %s): %s", errorData.Code, errorData.Message) logger.Info("Terminate reason (code: %s): %s", errorData.Code, errorData.Message)
if errorData.Code == "TERMINATED_INACTIVITY" { if errorData.Code == "TERMINATED_INACTIVITY" {
logger.Info("Ignoring...") logger.Info("Ignoring...")
return return
} }
// Set the olm error in the API server so it can be exposed via status // Set the olm error in the API server so it can be exposed via status
o.apiServer.SetOlmError(errorData.Code, errorData.Message) o.apiServer.SetOlmError(errorData.Code, errorData.Message)
} }

View File

@@ -2,6 +2,7 @@ package olm
import ( import (
"encoding/json" "encoding/json"
"fmt"
"time" "time"
"github.com/fosrl/newt/holepunch" "github.com/fosrl/newt/holepunch"
@@ -220,6 +221,7 @@ func (o *Olm) handleSync(msg websocket.WSMessage) {
logger.Info("Sync: Adding new peer for site %d", siteId) logger.Info("Sync: Adding new peer for site %d", siteId)
o.holePunchManager.TriggerHolePunch() o.holePunchManager.TriggerHolePunch()
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
// // TODO: do we need to send the message to the cloud to add the peer that way? // // TODO: do we need to send the message to the cloud to add the peer that way?
// if err := o.peerManager.AddPeer(expectedSite); err != nil { // if err := o.peerManager.AddPeer(expectedSite); err != nil {
@@ -230,9 +232,17 @@ func (o *Olm) handleSync(msg websocket.WSMessage) {
// add the peer via the server // add the peer via the server
// this is important because newt needs to get triggered as well to add the peer once the hp is complete // this is important because newt needs to get triggered as well to add the peer once the hp is complete
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ chainId := fmt.Sprintf("sync-%d", expectedSite.SiteId)
"siteId": expectedSite.SiteId, o.peerSendMu.Lock()
}, 1*time.Second, 10) if stop, ok := o.stopPeerSends[chainId]; ok {
stop()
}
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": expectedSite.SiteId,
"chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerSends[chainId] = stopFunc
o.peerSendMu.Unlock()
} else { } else {
// Existing peer - check if update is needed // Existing peer - check if update is needed

View File

@@ -2,6 +2,8 @@ package olm
import ( import (
"context" "context"
"crypto/rand"
"encoding/hex"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@@ -65,7 +67,10 @@ type Olm struct {
stopRegister func() stopRegister func()
updateRegister func(newData any) updateRegister func(newData any)
stopPeerSend func() stopPeerSends map[string]func()
stopPeerInits map[string]func()
jitPendingSites map[int]string // siteId -> chainId for in-flight JIT requests
peerSendMu sync.Mutex
// WaitGroup to track tunnel lifecycle // WaitGroup to track tunnel lifecycle
tunnelWg sync.WaitGroup tunnelWg sync.WaitGroup
@@ -116,6 +121,13 @@ func (o *Olm) initTunnelInfo(clientID string) error {
return nil return nil
} }
// generateChainId generates a random chain ID for tracking peer sender lifecycles.
func generateChainId() string {
b := make([]byte, 8)
_, _ = rand.Read(b)
return hex.EncodeToString(b)
}
func Init(ctx context.Context, config OlmConfig) (*Olm, error) { func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel)) logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
@@ -166,10 +178,13 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
apiServer.SetAgent(config.Agent) apiServer.SetAgent(config.Agent)
newOlm := &Olm{ newOlm := &Olm{
logFile: logFile, logFile: logFile,
olmCtx: ctx, olmCtx: ctx,
apiServer: apiServer, apiServer: apiServer,
olmConfig: config, olmConfig: config,
stopPeerSends: make(map[string]func()),
stopPeerInits: make(map[string]func()),
jitPendingSites: make(map[int]string),
} }
newOlm.registerAPICallbacks() newOlm.registerAPICallbacks()
@@ -284,6 +299,21 @@ func (o *Olm) registerAPICallbacks() {
logger.Info("Processing power mode change request via API: mode=%s", req.Mode) logger.Info("Processing power mode change request via API: mode=%s", req.Mode)
return o.SetPowerMode(req.Mode) return o.SetPowerMode(req.Mode)
}, },
func(req api.JITConnectionRequest) error {
logger.Info("Processing JIT connect request via API: site=%s resource=%s", req.Site, req.Resource)
chainId := generateChainId()
o.peerSendMu.Lock()
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{
"siteId": req.Site,
"resourceId": req.Resource,
"chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerInits[chainId] = stopFunc
o.peerSendMu.Unlock()
return nil
},
) )
} }
@@ -385,6 +415,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
// Handler for peer handshake - adds exit node to holepunch rotation and notifies server // Handler for peer handshake - adds exit node to holepunch rotation and notifies server
o.websocket.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite) o.websocket.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite)
o.websocket.RegisterHandler("olm/wg/peer/chain/cancel", o.handleCancelChain)
o.websocket.RegisterHandler("olm/sync", o.handleSync) o.websocket.RegisterHandler("olm/sync", o.handleSync)
o.websocket.OnConnect(func() error { o.websocket.OnConnect(func() error {
@@ -427,7 +458,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
"userToken": userToken, "userToken": userToken,
"fingerprint": o.fingerprint, "fingerprint": o.fingerprint,
"postures": o.postures, "postures": o.postures,
}, 1*time.Second, 10) }, 2*time.Second, 10)
// Invoke onRegistered callback if configured // Invoke onRegistered callback if configured
if o.olmConfig.OnRegistered != nil { if o.olmConfig.OnRegistered != nil {
@@ -524,6 +555,23 @@ func (o *Olm) Close() {
o.stopRegister = nil o.stopRegister = nil
} }
// Stop all pending peer init and send senders before closing websocket
o.peerSendMu.Lock()
for _, stop := range o.stopPeerInits {
if stop != nil {
stop()
}
}
o.stopPeerInits = make(map[string]func())
for _, stop := range o.stopPeerSends {
if stop != nil {
stop()
}
}
o.stopPeerSends = make(map[string]func())
o.jitPendingSites = make(map[int]string)
o.peerSendMu.Unlock()
// send a disconnect message to the cloud to show disconnected // send a disconnect message to the cloud to show disconnected
if o.websocket != nil { if o.websocket != nil {
o.websocket.SendMessage("olm/disconnecting", map[string]any{}) o.websocket.SendMessage("olm/disconnecting", map[string]any{})

View File

@@ -20,31 +20,43 @@ func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
return return
} }
if o.stopPeerSend != nil {
o.stopPeerSend()
o.stopPeerSend = nil
}
jsonData, err := json.Marshal(msg.Data) jsonData, err := json.Marshal(msg.Data)
if err != nil { if err != nil {
logger.Error("Error marshaling data: %v", err) logger.Error("Error marshaling data: %v", err)
return return
} }
var siteConfig peers.SiteConfig var siteConfigMsg struct {
if err := json.Unmarshal(jsonData, &siteConfig); err != nil { peers.SiteConfig
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &siteConfigMsg); err != nil {
logger.Error("Error unmarshaling add data: %v", err) logger.Error("Error unmarshaling add data: %v", err)
return return
} }
if siteConfigMsg.ChainId != "" {
o.peerSendMu.Lock()
if stop, ok := o.stopPeerSends[siteConfigMsg.ChainId]; ok {
stop()
delete(o.stopPeerSends, siteConfigMsg.ChainId)
}
o.peerSendMu.Unlock()
}
if siteConfigMsg.PublicKey == "" {
logger.Warn("Skipping add-peer for site %d (%s): no public key available (site may not be connected)", siteConfigMsg.SiteId, siteConfigMsg.Name)
return
}
_ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it _ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
if err := o.peerManager.AddPeer(siteConfig); err != nil { if err := o.peerManager.AddPeer(siteConfigMsg.SiteConfig); err != nil {
logger.Error("Failed to add peer: %v", err) logger.Error("Failed to add peer: %v", err)
return return
} }
logger.Info("Successfully added peer for site %d", siteConfig.SiteId) logger.Info("Successfully added peer for site %d", siteConfigMsg.SiteId)
} }
func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) { func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
@@ -164,13 +176,21 @@ func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) {
return return
} }
var relayData peers.RelayPeerData var relayData struct {
peers.RelayPeerData
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &relayData); err != nil { if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err) logger.Error("Error unmarshaling relay data: %v", err)
return return
} }
if monitor := o.peerManager.GetPeerMonitor(); monitor != nil {
monitor.CancelRelaySend(relayData.ChainId)
}
primaryRelay, err := util.ResolveDomainUpstream(relayData.RelayEndpoint, o.tunnelConfig.PublicDNS) primaryRelay, err := util.ResolveDomainUpstream(relayData.RelayEndpoint, o.tunnelConfig.PublicDNS)
if err != nil { if err != nil {
logger.Error("Failed to resolve primary relay endpoint: %v", err) logger.Error("Failed to resolve primary relay endpoint: %v", err)
return return
@@ -197,13 +217,21 @@ func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
return return
} }
var relayData peers.UnRelayPeerData var relayData struct {
peers.UnRelayPeerData
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &relayData); err != nil { if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err) logger.Error("Error unmarshaling relay data: %v", err)
return return
} }
if monitor := o.peerManager.GetPeerMonitor(); monitor != nil {
monitor.CancelRelaySend(relayData.ChainId)
}
primaryRelay, err := util.ResolveDomainUpstream(relayData.Endpoint, o.tunnelConfig.PublicDNS) primaryRelay, err := util.ResolveDomainUpstream(relayData.Endpoint, o.tunnelConfig.PublicDNS)
if err != nil { if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err) logger.Warn("Failed to resolve primary relay endpoint: %v", err)
} }
@@ -230,7 +258,8 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
} }
var handshakeData struct { var handshakeData struct {
SiteId int `json:"siteId"` SiteId int `json:"siteId"`
ChainId string `json:"chainId"`
ExitNode struct { ExitNode struct {
PublicKey string `json:"publicKey"` PublicKey string `json:"publicKey"`
Endpoint string `json:"endpoint"` Endpoint string `json:"endpoint"`
@@ -243,6 +272,19 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
return return
} }
// Stop the peer init sender for this chain, if any
if handshakeData.ChainId != "" {
o.peerSendMu.Lock()
if stop, ok := o.stopPeerInits[handshakeData.ChainId]; ok {
stop()
delete(o.stopPeerInits, handshakeData.ChainId)
}
// If this chain was initiated by a DNS-triggered JIT request, clear the
// pending entry so the site can be re-triggered if needed in the future.
delete(o.jitPendingSites, handshakeData.SiteId)
o.peerSendMu.Unlock()
}
// Get existing peer from PeerManager // Get existing peer from PeerManager
_, exists := o.peerManager.GetPeer(handshakeData.SiteId) _, exists := o.peerManager.GetPeer(handshakeData.SiteId)
if exists { if exists {
@@ -273,10 +315,72 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
// Send handshake acknowledgment back to server with retry // Send handshake acknowledgment back to server with retry, keyed by chainId
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ chainId := handshakeData.ChainId
"siteId": handshakeData.SiteId, if chainId == "" {
}, 1*time.Second, 10) chainId = generateChainId()
}
o.peerSendMu.Lock()
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": handshakeData.SiteId,
"chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerSends[chainId] = stopFunc
o.peerSendMu.Unlock()
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
} }
func (o *Olm) handleCancelChain(msg websocket.WSMessage) {
logger.Debug("Received cancel-chain message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling cancel-chain data: %v", err)
return
}
var cancelData struct {
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &cancelData); err != nil {
logger.Error("Error unmarshaling cancel-chain data: %v", err)
return
}
if cancelData.ChainId == "" {
logger.Warn("Received cancel-chain message with no chainId")
return
}
o.peerSendMu.Lock()
defer o.peerSendMu.Unlock()
found := false
if stop, ok := o.stopPeerInits[cancelData.ChainId]; ok {
stop()
delete(o.stopPeerInits, cancelData.ChainId)
found = true
}
// If this chain was a DNS-triggered JIT request, clear the pending entry so
// the site can be re-triggered on the next DNS lookup.
for siteId, chainId := range o.jitPendingSites {
if chainId == cancelData.ChainId {
delete(o.jitPendingSites, siteId)
break
}
}
if stop, ok := o.stopPeerSends[cancelData.ChainId]; ok {
stop()
delete(o.stopPeerSends, cancelData.ChainId)
found = true
}
if found {
logger.Info("Cancelled chain %s", cancelData.ChainId)
} else {
logger.Warn("Cancel-chain: no active sender found for chain %s", cancelData.ChainId)
}
}

View File

@@ -110,6 +110,19 @@ func (pm *PeerManager) GetAllPeers() []SiteConfig {
func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error { func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
for _, alias := range siteConfig.Aliases {
address := net.ParseIP(alias.AliasAddress)
if address == nil {
continue
}
pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteConfig.SiteId)
}
if siteConfig.PublicKey == "" {
logger.Debug("Skip adding site %d because no pub key", siteConfig.SiteId)
return nil
}
// build the allowed IPs list from the remote subnets and aliases and add them to the peer // build the allowed IPs list from the remote subnets and aliases and add them to the peer
allowedIPs := make([]string, 0, len(siteConfig.RemoteSubnets)+len(siteConfig.Aliases)) allowedIPs := make([]string, 0, len(siteConfig.RemoteSubnets)+len(siteConfig.Aliases))
@@ -143,14 +156,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
if err := network.AddRoutes(siteConfig.RemoteSubnets, pm.interfaceName); err != nil { if err := network.AddRoutes(siteConfig.RemoteSubnets, pm.interfaceName); err != nil {
logger.Error("Failed to add routes for remote subnets: %v", err) logger.Error("Failed to add routes for remote subnets: %v", err)
} }
for _, alias := range siteConfig.Aliases {
address := net.ParseIP(alias.AliasAddress)
if address == nil {
continue
}
pm.dnsProxy.AddDNSRecord(alias.Alias, address)
}
monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0] monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
@@ -437,7 +443,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
if address == nil { if address == nil {
continue continue
} }
pm.dnsProxy.AddDNSRecord(alias.Alias, address) pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteConfig.SiteId)
} }
pm.peerMonitor.UpdateHolepunchEndpoint(siteConfig.SiteId, siteConfig.Endpoint) pm.peerMonitor.UpdateHolepunchEndpoint(siteConfig.SiteId, siteConfig.Endpoint)
@@ -717,7 +723,7 @@ func (pm *PeerManager) AddAlias(siteId int, alias Alias) error {
address := net.ParseIP(alias.AliasAddress) address := net.ParseIP(alias.AliasAddress)
if address != nil { if address != nil {
pm.dnsProxy.AddDNSRecord(alias.Alias, address) pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteId)
} }
// Add an allowed IP for the alias // Add an allowed IP for the alias

View File

@@ -2,6 +2,8 @@ package monitor
import ( import (
"context" "context"
"crypto/rand"
"encoding/hex"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@@ -31,11 +33,15 @@ type PeerMonitor struct {
monitors map[int]*Client monitors map[int]*Client
mutex sync.Mutex mutex sync.Mutex
running bool running bool
timeout time.Duration timeout time.Duration
maxAttempts int maxAttempts int
wsClient *websocket.Client wsClient *websocket.Client
publicDNS []string publicDNS []string
// Relay sender tracking
relaySends map[string]func()
relaySendMu sync.Mutex
// Netstack fields // Netstack fields
middleDev *middleDevice.MiddleDevice middleDev *middleDevice.MiddleDevice
localIP string localIP string
@@ -48,13 +54,13 @@ type PeerMonitor struct {
nsWg sync.WaitGroup nsWg sync.WaitGroup
// Holepunch testing fields // Holepunch testing fields
sharedBind *bind.SharedBind sharedBind *bind.SharedBind
holepunchTester *holepunch.HolepunchTester holepunchTester *holepunch.HolepunchTester
holepunchTimeout time.Duration holepunchTimeout time.Duration
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
holepunchStatus map[int]bool // siteID -> connected status holepunchStatus map[int]bool // siteID -> connected status
holepunchStopChan chan struct{} holepunchStopChan chan struct{}
holepunchUpdateChan chan struct{} holepunchUpdateChan chan struct{}
// Relay tracking fields // Relay tracking fields
relayedPeers map[int]bool // siteID -> whether the peer is currently relayed relayedPeers map[int]bool // siteID -> whether the peer is currently relayed
@@ -83,6 +89,12 @@ type PeerMonitor struct {
} }
// NewPeerMonitor creates a new peer monitor with the given callback // NewPeerMonitor creates a new peer monitor with the given callback
func generateChainId() string {
b := make([]byte, 8)
_, _ = rand.Read(b)
return hex.EncodeToString(b)
}
func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API, publicDNS []string) *PeerMonitor { func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API, publicDNS []string) *PeerMonitor {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
pm := &PeerMonitor{ pm := &PeerMonitor{
@@ -101,6 +113,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
holepunchEndpoints: make(map[int]string), holepunchEndpoints: make(map[int]string),
holepunchStatus: make(map[int]bool), holepunchStatus: make(map[int]bool),
relayedPeers: make(map[int]bool), relayedPeers: make(map[int]bool),
relaySends: make(map[string]func()),
holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures
holepunchFailures: make(map[int]int), holepunchFailures: make(map[int]int),
// Rapid initial test settings: complete within ~1.5 seconds // Rapid initial test settings: complete within ~1.5 seconds
@@ -398,20 +411,23 @@ func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status Connectio
} }
} }
// sendRelay sends a relay message to the server // sendRelay sends a relay message to the server with retry, keyed by chainId
func (pm *PeerMonitor) sendRelay(siteID int) error { func (pm *PeerMonitor) sendRelay(siteID int) error {
if pm.wsClient == nil { if pm.wsClient == nil {
return fmt.Errorf("websocket client is nil") return fmt.Errorf("websocket client is nil")
} }
err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{ chainId := generateChainId()
"siteId": siteID, stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/relay", map[string]interface{}{
}) "siteId": siteID,
if err != nil { "chainId": chainId,
logger.Error("Failed to send registration message: %v", err) }, 2*time.Second, 10)
return err
} pm.relaySendMu.Lock()
logger.Info("Sent relay message") pm.relaySends[chainId] = stopFunc
pm.relaySendMu.Unlock()
logger.Info("Sent relay message for site %d (chain %s)", siteID, chainId)
return nil return nil
} }
@@ -421,23 +437,52 @@ func (pm *PeerMonitor) RequestRelay(siteID int) error {
return pm.sendRelay(siteID) return pm.sendRelay(siteID)
} }
// sendUnRelay sends an unrelay message to the server // sendUnRelay sends an unrelay message to the server with retry, keyed by chainId
func (pm *PeerMonitor) sendUnRelay(siteID int) error { func (pm *PeerMonitor) sendUnRelay(siteID int) error {
if pm.wsClient == nil { if pm.wsClient == nil {
return fmt.Errorf("websocket client is nil") return fmt.Errorf("websocket client is nil")
} }
err := pm.wsClient.SendMessage("olm/wg/unrelay", map[string]interface{}{ chainId := generateChainId()
"siteId": siteID, stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/unrelay", map[string]interface{}{
}) "siteId": siteID,
if err != nil { "chainId": chainId,
logger.Error("Failed to send registration message: %v", err) }, 2*time.Second, 10)
return err
} pm.relaySendMu.Lock()
logger.Info("Sent unrelay message") pm.relaySends[chainId] = stopFunc
pm.relaySendMu.Unlock()
logger.Info("Sent unrelay message for site %d (chain %s)", siteID, chainId)
return nil return nil
} }
// CancelRelaySend stops the interval sender for the given chainId, if one exists.
// If chainId is empty, all active relay senders are stopped.
func (pm *PeerMonitor) CancelRelaySend(chainId string) {
pm.relaySendMu.Lock()
defer pm.relaySendMu.Unlock()
if chainId == "" {
for id, stop := range pm.relaySends {
if stop != nil {
stop()
}
delete(pm.relaySends, id)
}
logger.Info("Cancelled all relay senders")
return
}
if stop, ok := pm.relaySends[chainId]; ok {
stop()
delete(pm.relaySends, chainId)
logger.Info("Cancelled relay sender for chain %s", chainId)
} else {
logger.Warn("CancelRelaySend: no active sender for chain %s", chainId)
}
}
// Stop stops monitoring all peers // Stop stops monitoring all peers
func (pm *PeerMonitor) Stop() { func (pm *PeerMonitor) Stop() {
// Stop holepunch monitor first (outside of mutex to avoid deadlock) // Stop holepunch monitor first (outside of mutex to avoid deadlock)
@@ -536,7 +581,7 @@ func (pm *PeerMonitor) runHolepunchMonitor() {
pm.holepunchCurrentInterval = pm.holepunchMinInterval pm.holepunchCurrentInterval = pm.holepunchMinInterval
currentInterval := pm.holepunchCurrentInterval currentInterval := pm.holepunchCurrentInterval
pm.mutex.Unlock() pm.mutex.Unlock()
timer.Reset(currentInterval) timer.Reset(currentInterval)
logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval) logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval)
case <-timer.C: case <-timer.C:
@@ -679,6 +724,16 @@ func (pm *PeerMonitor) Close() {
// Stop holepunch monitor first (outside of mutex to avoid deadlock) // Stop holepunch monitor first (outside of mutex to avoid deadlock)
pm.stopHolepunchMonitor() pm.stopHolepunchMonitor()
// Stop all pending relay senders
pm.relaySendMu.Lock()
for chainId, stop := range pm.relaySends {
if stop != nil {
stop()
}
delete(pm.relaySends, chainId)
}
pm.relaySendMu.Unlock()
pm.mutex.Lock() pm.mutex.Lock()
defer pm.mutex.Unlock() defer pm.mutex.Unlock()