diff --git a/api/api.go b/api/api.go index 047ce08..895140b 100644 --- a/api/api.go +++ b/api/api.go @@ -78,6 +78,13 @@ type MetadataChangeRequest struct { Postures map[string]any `json:"postures"` } +// JITConnectionRequest defines the structure for a dynamic Just-In-Time connection request. +// Exactly one of Site or Resource must be set; handleJITConnect rejects requests that set neither or both. +type JITConnectionRequest struct { + Site string `json:"site,omitempty"` + Resource string `json:"resource,omitempty"` +} + // API represents the HTTP server and its state type API struct { addr string @@ -92,6 +99,7 @@ type API struct { onExit func() error onRebind func() error onPowerMode func(PowerModeRequest) error + onJITConnect func(JITConnectionRequest) error statusMu sync.RWMutex peerStatuses map[int]*PeerStatus @@ -143,6 +151,7 @@ func (s *API) SetHandlers( onExit func() error, onRebind func() error, onPowerMode func(PowerModeRequest) error, + onJITConnect func(JITConnectionRequest) error, ) { s.onConnect = onConnect s.onSwitchOrg = onSwitchOrg @@ -151,6 +160,7 @@ func (s *API) SetHandlers( s.onExit = onExit s.onRebind = onRebind s.onPowerMode = onPowerMode + s.onJITConnect = onJITConnect } // Start starts the HTTP server @@ -169,6 +179,7 @@ func (s *API) Start() error { mux.HandleFunc("/health", s.handleHealth) mux.HandleFunc("/rebind", s.handleRebind) mux.HandleFunc("/power-mode", s.handlePowerMode) + mux.HandleFunc("/jit-connect", s.handleJITConnect) s.server = &http.Server{ Handler: mux, @@ -633,6 +644,54 @@ func (s *API) handleRebind(w http.ResponseWriter, r *http.Request) { }) } +// handleJITConnect handles the /jit-connect endpoint. +// It initiates a dynamic Just-In-Time connection to a site identified by either +// a site or a resource. Exactly one of the two must be provided. 
+func (s *API) handleJITConnect(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req JITConnectionRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest) + return + } + + // Validate that exactly one of site or resource is provided + if req.Site == "" && req.Resource == "" { + http.Error(w, "Missing required field: either site or resource must be provided", http.StatusBadRequest) + return + } + if req.Site != "" && req.Resource != "" { + http.Error(w, "Ambiguous request: provide either site or resource, not both", http.StatusBadRequest) + return + } + + if req.Site != "" { + logger.Info("Received JIT connection request via API: site=%s", req.Site) + } else { + logger.Info("Received JIT connection request via API: resource=%s", req.Resource) + } + + if s.onJITConnect != nil { + if err := s.onJITConnect(req); err != nil { + http.Error(w, fmt.Sprintf("JIT connection failed: %v", err), http.StatusInternalServerError) + return + } + } else { + http.Error(w, "JIT connect handler not configured", http.StatusNotImplemented) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + _ = json.NewEncoder(w).Encode(map[string]string{ + "status": "JIT connection request accepted", + }) +} + // handlePowerMode handles the /power-mode endpoint // This allows changing the power mode between "normal" and "low" func (s *API) handlePowerMode(w http.ResponseWriter, r *http.Request) { diff --git a/dns/dns_proxy.go b/dns/dns_proxy.go index 986e847..9451ba8 100644 --- a/dns/dns_proxy.go +++ b/dns/dns_proxy.go @@ -45,6 +45,11 @@ type DNSProxy struct { tunnelActivePorts map[uint16]bool tunnelPortsLock sync.Mutex + // jitHandler is called when a local record is resolved for a site that may not be + // connected yet, giving the 
caller a chance to initiate a JIT connection. + // It is invoked asynchronously so it never blocks DNS resolution. + jitHandler func(siteId int) + ctx context.Context cancel context.CancelFunc wg sync.WaitGroup @@ -384,6 +389,16 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie response = p.checkLocalRecords(msg, question) } + // If a local A/AAAA record was found, notify the JIT handler so that the owning + // site can be connected on-demand if it is not yet active. + if response != nil && p.jitHandler != nil && + (question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA) { + if siteId, ok := p.recordStore.GetSiteIdForDomain(question.Name); ok && siteId != 0 { + handler := p.jitHandler + go handler(siteId) + } + } + // If no local records, forward to upstream if response == nil { logger.Debug("No local record for %s, forwarding upstream to %v", question.Name, p.upstreamDNS) @@ -447,19 +462,20 @@ func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns return nil } - ips := p.recordStore.GetRecords(question.Name, recordType) - if len(ips) == 0 { + ips, exists := p.recordStore.GetRecords(question.Name, recordType) + if !exists { + // Domain not found in local records, forward to upstream return nil } logger.Debug("Found %d local record(s) for %s", len(ips), question.Name) - // Create response message + // Create response message (NODATA if no records, otherwise with answers) response := new(dns.Msg) response.SetReply(query) response.Authoritative = true - // Add answer records + // Add answer records (loop is a no-op if ips is empty) for _, ip := range ips { var rr dns.RR if question.Qtype == dns.TypeA { @@ -717,11 +733,20 @@ func (p *DNSProxy) runPacketSender() { } } +// SetJITHandler registers a callback that is invoked whenever a local DNS record is +// resolved for an A or AAAA query. The siteId identifies which site owns the record. 
+// The handler is called in its own goroutine so it must be safe to call concurrently. +// Pass nil to disable JIT notifications. NOTE(review): p.jitHandler is assigned here without a lock and is read in handleDNSQuery; confirm it is only set before the proxy starts serving queries. +func (p *DNSProxy) SetJITHandler(handler func(siteId int)) { + p.jitHandler = handler +} + // AddDNSRecord adds a DNS record to the local store // domain should be a domain name (e.g., "example.com" or "example.com.") // ip should be a valid IPv4 or IPv6 address -func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP) error { - return p.recordStore.AddRecord(domain, ip) +func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP, siteId int) error { + logger.Debug("Adding dns record for domain %s with IP %s (siteId=%d)", domain, ip.String(), siteId) + return p.recordStore.AddRecord(domain, ip, siteId) } // RemoveDNSRecord removes a DNS record from the local store @@ -730,8 +755,9 @@ func (p *DNSProxy) RemoveDNSRecord(domain string, ip net.IP) { p.recordStore.RemoveRecord(domain, ip) } -// GetDNSRecords returns all IP addresses for a domain and record type -func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) []net.IP { +// GetDNSRecords returns all IP addresses for a domain and record type. +// The second return value indicates whether the domain exists. 
+func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) ([]net.IP, bool) { return p.recordStore.GetRecords(domain, recordType) } diff --git a/dns/dns_proxy_test.go b/dns/dns_proxy_test.go new file mode 100644 index 0000000..9eecad7 --- /dev/null +++ b/dns/dns_proxy_test.go @@ -0,0 +1,178 @@ +package dns + +import ( + "net" + "testing" + + "github.com/miekg/dns" +) + +func TestCheckLocalRecordsNODATAForAAAA(t *testing.T) { + proxy := &DNSProxy{ + recordStore: NewDNSRecordStore(), + } + + // Add an A record for a domain (no AAAA record) + ip := net.ParseIP("10.0.0.1") + err := proxy.recordStore.AddRecord("myservice.internal", ip, 0) + if err != nil { + t.Fatalf("Failed to add A record: %v", err) + } + + // Query AAAA for domain with only A record - should return NODATA + query := new(dns.Msg) + query.SetQuestion("myservice.internal.", dns.TypeAAAA) + response := proxy.checkLocalRecords(query, query.Question[0]) + + if response == nil { + t.Fatal("Expected NODATA response, got nil (would forward to upstream)") + } + if response.Rcode != dns.RcodeSuccess { + t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode) + } + if len(response.Answer) != 0 { + t.Errorf("Expected empty answer section for NODATA, got %d answers", len(response.Answer)) + } + if !response.Authoritative { + t.Error("Expected response to be authoritative") + } + + // Query A for same domain - should return the record + query = new(dns.Msg) + query.SetQuestion("myservice.internal.", dns.TypeA) + response = proxy.checkLocalRecords(query, query.Question[0]) + + if response == nil { + t.Fatal("Expected response with A record, got nil") + } + if len(response.Answer) != 1 { + t.Fatalf("Expected 1 answer, got %d", len(response.Answer)) + } + aRecord, ok := response.Answer[0].(*dns.A) + if !ok { + t.Fatal("Expected A record in answer") + } + if !aRecord.A.Equal(ip.To4()) { + t.Errorf("Expected IP %v, got %v", ip.To4(), aRecord.A) + } +} + +func TestCheckLocalRecordsNODATAForA(t 
*testing.T) { + proxy := &DNSProxy{ + recordStore: NewDNSRecordStore(), + } + + // Add an AAAA record for a domain (no A record) + ip := net.ParseIP("2001:db8::1") + err := proxy.recordStore.AddRecord("ipv6only.internal", ip, 0) + if err != nil { + t.Fatalf("Failed to add AAAA record: %v", err) + } + + // Query A for domain with only AAAA record - should return NODATA + query := new(dns.Msg) + query.SetQuestion("ipv6only.internal.", dns.TypeA) + response := proxy.checkLocalRecords(query, query.Question[0]) + + if response == nil { + t.Fatal("Expected NODATA response, got nil") + } + if response.Rcode != dns.RcodeSuccess { + t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode) + } + if len(response.Answer) != 0 { + t.Errorf("Expected empty answer section, got %d answers", len(response.Answer)) + } + if !response.Authoritative { + t.Error("Expected response to be authoritative") + } + + // Query AAAA for same domain - should return the record + query = new(dns.Msg) + query.SetQuestion("ipv6only.internal.", dns.TypeAAAA) + response = proxy.checkLocalRecords(query, query.Question[0]) + + if response == nil { + t.Fatal("Expected response with AAAA record, got nil") + } + if len(response.Answer) != 1 { + t.Fatalf("Expected 1 answer, got %d", len(response.Answer)) + } + aaaaRecord, ok := response.Answer[0].(*dns.AAAA) + if !ok { + t.Fatal("Expected AAAA record in answer") + } + if !aaaaRecord.AAAA.Equal(ip) { + t.Errorf("Expected IP %v, got %v", ip, aaaaRecord.AAAA) + } +} + +func TestCheckLocalRecordsNonExistentDomain(t *testing.T) { + proxy := &DNSProxy{ + recordStore: NewDNSRecordStore(), + } + + // Add a record so the store isn't empty + err := proxy.recordStore.AddRecord("exists.internal", net.ParseIP("10.0.0.1"), 0) + if err != nil { + t.Fatalf("Failed to add record: %v", err) + } + + // Query A for non-existent domain - should return nil (forward to upstream) + query := new(dns.Msg) + query.SetQuestion("unknown.internal.", dns.TypeA) + response := 
proxy.checkLocalRecords(query, query.Question[0]) + + if response != nil { + t.Error("Expected nil for non-existent domain, got response") + } + + // Query AAAA for non-existent domain - should also return nil + query = new(dns.Msg) + query.SetQuestion("unknown.internal.", dns.TypeAAAA) + response = proxy.checkLocalRecords(query, query.Question[0]) + + if response != nil { + t.Error("Expected nil for non-existent domain, got response") + } +} + +func TestCheckLocalRecordsNODATAWildcard(t *testing.T) { + proxy := &DNSProxy{ + recordStore: NewDNSRecordStore(), + } + + // Add a wildcard A record (no AAAA) + ip := net.ParseIP("10.0.0.1") + err := proxy.recordStore.AddRecord("*.wildcard.internal", ip, 0) + if err != nil { + t.Fatalf("Failed to add wildcard A record: %v", err) + } + + // Query AAAA for wildcard-matched domain - should return NODATA + query := new(dns.Msg) + query.SetQuestion("host.wildcard.internal.", dns.TypeAAAA) + response := proxy.checkLocalRecords(query, query.Question[0]) + + if response == nil { + t.Fatal("Expected NODATA response for wildcard match, got nil") + } + if response.Rcode != dns.RcodeSuccess { + t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode) + } + if len(response.Answer) != 0 { + t.Errorf("Expected empty answer section, got %d answers", len(response.Answer)) + } + + // Query A for wildcard-matched domain - should return the record + query = new(dns.Msg) + query.SetQuestion("host.wildcard.internal.", dns.TypeA) + response = proxy.checkLocalRecords(query, query.Question[0]) + + if response == nil { + t.Fatal("Expected response with A record, got nil") + } + if len(response.Answer) != 1 { + t.Fatalf("Expected 1 answer, got %d", len(response.Answer)) + } +} diff --git a/dns/dns_records.go b/dns/dns_records.go index 199b94b..270bae6 100644 --- a/dns/dns_records.go +++ b/dns/dns_records.go @@ -18,24 +18,28 @@ const ( RecordTypePTR RecordType = RecordType(dns.TypePTR) ) -// DNSRecordStore manages local DNS records for A, AAAA, 
and PTR queries +// recordSet holds A and AAAA records for a single domain or wildcard pattern +type recordSet struct { + A []net.IP + AAAA []net.IP + SiteId int +} + +// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries. +// Exact domains are stored in a map; wildcard patterns are in a separate map. type DNSRecordStore struct { - mu sync.RWMutex - aRecords map[string][]net.IP // domain -> list of IPv4 addresses - aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses - aWildcards map[string][]net.IP // wildcard pattern -> list of IPv4 addresses - aaaaWildcards map[string][]net.IP // wildcard pattern -> list of IPv6 addresses - ptrRecords map[string]string // IP address string -> domain name + mu sync.RWMutex + exact map[string]*recordSet // normalized FQDN -> A/AAAA records + wildcards map[string]*recordSet // wildcard pattern -> A/AAAA records + ptrRecords map[string]string // IP address string -> domain name } // NewDNSRecordStore creates a new DNS record store func NewDNSRecordStore() *DNSRecordStore { return &DNSRecordStore{ - aRecords: make(map[string][]net.IP), - aaaaRecords: make(map[string][]net.IP), - aWildcards: make(map[string][]net.IP), - aaaaWildcards: make(map[string][]net.IP), - ptrRecords: make(map[string]string), + exact: make(map[string]*recordSet), + wildcards: make(map[string]*recordSet), + ptrRecords: make(map[string]string), } } @@ -43,47 +47,57 @@ func NewDNSRecordStore() *DNSRecordStore { // domain should be in FQDN format (e.g., "example.com.") // domain can contain wildcards: * (0+ chars) and ? 
(exactly 1 char) // ip should be a valid IPv4 or IPv6 address +// siteId is the site that owns this alias/domain; it is stored only when the domain's record set is first created, so a later call with a different siteId leaves the stored value unchanged // Automatically adds a corresponding PTR record for non-wildcard domains -func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error { +func (s *DNSRecordStore) AddRecord(domain string, ip net.IP, siteId int) error { s.mu.Lock() defer s.mu.Unlock() - // Ensure domain ends with a dot (FQDN format) if len(domain) == 0 || domain[len(domain)-1] != '.' { domain = domain + "." } - - // Normalize domain to lowercase FQDN domain = strings.ToLower(dns.Fqdn(domain)) - - // Check if domain contains wildcards isWildcard := strings.ContainsAny(domain, "*?") - if ip.To4() != nil { - // IPv4 address - if isWildcard { - s.aWildcards[domain] = append(s.aWildcards[domain], ip) - } else { - s.aRecords[domain] = append(s.aRecords[domain], ip) - // Automatically add PTR record for non-wildcard domains - s.ptrRecords[ip.String()] = domain - } - } else if ip.To16() != nil { - // IPv6 address - if isWildcard { - s.aaaaWildcards[domain] = append(s.aaaaWildcards[domain], ip) - } else { - s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip) - // Automatically add PTR record for non-wildcard domains - s.ptrRecords[ip.String()] = domain - } - } else { + isV4 := ip.To4() != nil + if !isV4 && ip.To16() == nil { return &net.ParseError{Type: "IP address", Text: ip.String()} } + // Choose the appropriate map based on whether this is a wildcard + m := s.exact + if isWildcard { + m = s.wildcards + } + + if m[domain] == nil { + m[domain] = &recordSet{SiteId: siteId} + } + rs := m[domain] + if isV4 { + for _, existing := range rs.A { + if existing.Equal(ip) { + return nil + } + } + rs.A = append(rs.A, ip) + } else { + for _, existing := range rs.AAAA { + if existing.Equal(ip) { + return nil + } + } + rs.AAAA = append(rs.AAAA, ip) + } + + // Add PTR record for non-wildcard domains + if !isWildcard { + s.ptrRecords[ip.String()] = domain + } return nil } + // AddPTRRecord adds a 
PTR record mapping an IP address to a domain name // ip should be a valid IPv4 or IPv6 address // domain should be in FQDN format (e.g., "example.com.") @@ -112,89 +126,62 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) { s.mu.Lock() defer s.mu.Unlock() - // Ensure domain ends with a dot (FQDN format) if len(domain) == 0 || domain[len(domain)-1] != '.' { domain = domain + "." } - - // Normalize domain to lowercase FQDN domain = strings.ToLower(dns.Fqdn(domain)) - - // Check if domain contains wildcards isWildcard := strings.ContainsAny(domain, "*?") - if ip == nil { - // Remove all records for this domain - if isWildcard { - delete(s.aWildcards, domain) - delete(s.aaaaWildcards, domain) - } else { - // For non-wildcard domains, remove PTR records for all IPs - if ips, ok := s.aRecords[domain]; ok { - for _, ipAddr := range ips { - // Only remove PTR if it points to this domain - if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ipAddr.String()) - } - } - } - if ips, ok := s.aaaaRecords[domain]; ok { - for _, ipAddr := range ips { - // Only remove PTR if it points to this domain - if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ipAddr.String()) - } - } - } - delete(s.aRecords, domain) - delete(s.aaaaRecords, domain) - } + // Choose the appropriate map + m := s.exact + if isWildcard { + m = s.wildcards + } + + rs := m[domain] + if rs == nil { return } + if ip == nil { + // Remove all records for this domain + if !isWildcard { + for _, ipAddr := range rs.A { + if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ipAddr.String()) + } + } + for _, ipAddr := range rs.AAAA { + if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ipAddr.String()) + } + } + } + delete(m, domain) + return + } + + // Remove specific IP if 
ip.To4() != nil { - // Remove specific IPv4 address - if isWildcard { - if ips, ok := s.aWildcards[domain]; ok { - s.aWildcards[domain] = removeIP(ips, ip) - if len(s.aWildcards[domain]) == 0 { - delete(s.aWildcards, domain) - } - } - } else { - if ips, ok := s.aRecords[domain]; ok { - s.aRecords[domain] = removeIP(ips, ip) - if len(s.aRecords[domain]) == 0 { - delete(s.aRecords, domain) - } - // Automatically remove PTR record if it points to this domain - if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ip.String()) - } + rs.A = removeIP(rs.A, ip) + if !isWildcard { + if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ip.String()) } } - } else if ip.To16() != nil { - // Remove specific IPv6 address - if isWildcard { - if ips, ok := s.aaaaWildcards[domain]; ok { - s.aaaaWildcards[domain] = removeIP(ips, ip) - if len(s.aaaaWildcards[domain]) == 0 { - delete(s.aaaaWildcards, domain) - } - } - } else { - if ips, ok := s.aaaaRecords[domain]; ok { - s.aaaaRecords[domain] = removeIP(ips, ip) - if len(s.aaaaRecords[domain]) == 0 { - delete(s.aaaaRecords, domain) - } - // Automatically remove PTR record if it points to this domain - if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ip.String()) - } + } else { + rs.AAAA = removeIP(rs.AAAA, ip) + if !isWildcard { + if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ip.String()) } } } + + // Clean up empty record sets + if len(rs.A) == 0 && len(rs.AAAA) == 0 { + delete(m, domain) + } } // RemovePTRRecord removes a PTR record for an IP address @@ -205,61 +192,80 @@ func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) { delete(s.ptrRecords, ip.String()) } -// GetRecords returns all IP addresses for a domain and record type -// First checks for exact matches, then checks wildcard patterns -func (s 
*DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP { +// GetSiteIdForDomain returns the siteId associated with the given domain. +// It checks exact matches first, then wildcard patterns; when several wildcard patterns match, the first one hit while ranging over the wildcard map is returned, and Go does not specify that order. +// The second return value is false if the domain is not found in local records. +func (s *DNSRecordStore) GetSiteIdForDomain(domain string) (int, bool) { s.mu.RLock() defer s.mu.RUnlock() - // Normalize domain to lowercase FQDN domain = strings.ToLower(dns.Fqdn(domain)) - var records []net.IP - switch recordType { - case RecordTypeA: - // Check exact match first - if ips, ok := s.aRecords[domain]; ok { - // Return a copy to prevent external modifications - records = make([]net.IP, len(ips)) - copy(records, ips) - return records - } - // Check wildcard patterns - for pattern, ips := range s.aWildcards { - if matchWildcard(pattern, domain) { - records = append(records, ips...) - } - } - if len(records) > 0 { - // Return a copy - result := make([]net.IP, len(records)) - copy(result, records) - return result - } + // Check exact match first + if rs, exists := s.exact[domain]; exists { + return rs.SiteId, true + } - case RecordTypeAAAA: - // Check exact match first - if ips, ok := s.aaaaRecords[domain]; ok { - // Return a copy to prevent external modifications - records = make([]net.IP, len(ips)) - copy(records, ips) - return records - } - // Check wildcard patterns - for pattern, ips := range s.aaaaWildcards { - if matchWildcard(pattern, domain) { - records = append(records, ips...) - } - } - if len(records) > 0 { - // Return a copy - result := make([]net.IP, len(records)) - copy(result, records) - return result + // Check wildcard matches + for pattern, rs := range s.wildcards { + if matchWildcard(pattern, domain) { + return rs.SiteId, true } } - return records + return 0, false +} + +// GetRecords returns all IP addresses for a domain and record type. 
+// The second return value indicates whether the domain exists at all +// (true = an exact or wildcard entry matches, use NODATA if no records; false = nothing matches, in which case checkLocalRecords returns nil and the query is forwarded upstream rather than answered NXDOMAIN). Any recordType other than RecordTypeA selects the AAAA list. +func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) ([]net.IP, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + + domain = strings.ToLower(dns.Fqdn(domain)) + + // Check exact match first + if rs, exists := s.exact[domain]; exists { + var ips []net.IP + if recordType == RecordTypeA { + ips = rs.A + } else { + ips = rs.AAAA + } + if len(ips) > 0 { + out := make([]net.IP, len(ips)) + copy(out, ips) + return out, true + } + // Domain exists but no records of this type + return nil, true + } + + // Check wildcard matches + var records []net.IP + matched := false + for pattern, rs := range s.wildcards { + if !matchWildcard(pattern, domain) { + continue + } + matched = true + if recordType == RecordTypeA { + records = append(records, rs.A...) + } else { + records = append(records, rs.AAAA...) + } + } + + if !matched { + return nil, false + } + if len(records) == 0 { + return nil, true + } + out := make([]net.IP, len(records)) + copy(out, records) + return out, true } // GetPTRRecord returns the domain name for a PTR record query @@ -288,34 +294,30 @@ func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool { s.mu.RLock() defer s.mu.RUnlock() - // Normalize domain to lowercase FQDN domain = strings.ToLower(dns.Fqdn(domain)) - switch recordType { - case RecordTypeA: - // Check exact match - if _, ok := s.aRecords[domain]; ok { + // Check exact match + if rs, exists := s.exact[domain]; exists { + if recordType == RecordTypeA && len(rs.A) > 0 { return true } - // Check wildcard patterns - for pattern := range s.aWildcards { - if matchWildcard(pattern, domain) { - return true - } - } - case RecordTypeAAAA: - // Check exact match - if _, ok := s.aaaaRecords[domain]; ok { + if recordType == RecordTypeAAAA && len(rs.AAAA) > 0 { return true } - // Check wildcard patterns - for pattern := range 
s.aaaaWildcards { - if matchWildcard(pattern, domain) { - return true - } - } } + // Check wildcard matches + for pattern, rs := range s.wildcards { + if !matchWildcard(pattern, domain) { + continue + } + if recordType == RecordTypeA && len(rs.A) > 0 { + return true + } + if recordType == RecordTypeAAAA && len(rs.AAAA) > 0 { + return true + } + } return false } @@ -339,10 +341,8 @@ func (s *DNSRecordStore) Clear() { s.mu.Lock() defer s.mu.Unlock() - s.aRecords = make(map[string][]net.IP) - s.aaaaRecords = make(map[string][]net.IP) - s.aWildcards = make(map[string][]net.IP) - s.aaaaWildcards = make(map[string][]net.IP) + s.exact = make(map[string]*recordSet) + s.wildcards = make(map[string]*recordSet) s.ptrRecords = make(map[string]string) } @@ -494,4 +494,4 @@ func IPToReverseDNS(ip net.IP) string { } return "" -} \ No newline at end of file +} diff --git a/dns/dns_records_test.go b/dns/dns_records_test.go index eae9372..0b4481d 100644 --- a/dns/dns_records_test.go +++ b/dns/dns_records_test.go @@ -170,38 +170,47 @@ func TestDNSRecordStoreWildcard(t *testing.T) { // Add wildcard records wildcardIP := net.ParseIP("10.0.0.1") - err := store.AddRecord("*.autoco.internal", wildcardIP) + err := store.AddRecord("*.autoco.internal", wildcardIP, 0) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } // Add exact record exactIP := net.ParseIP("10.0.0.2") - err = store.AddRecord("exact.autoco.internal", exactIP) + err = store.AddRecord("exact.autoco.internal", exactIP, 0) if err != nil { t.Fatalf("Failed to add exact record: %v", err) } // Test exact match takes precedence - ips := store.GetRecords("exact.autoco.internal.", RecordTypeA) + ips, exists := store.GetRecords("exact.autoco.internal.", RecordTypeA) + if !exists { + t.Error("Expected domain to exist") + } if len(ips) != 1 { t.Errorf("Expected 1 IP for exact match, got %d", len(ips)) } - if !ips[0].Equal(exactIP) { + if len(ips) > 0 && !ips[0].Equal(exactIP) { t.Errorf("Expected exact IP %v, got %v", 
exactIP, ips[0]) } // Test wildcard match - ips = store.GetRecords("host.autoco.internal.", RecordTypeA) + ips, exists = store.GetRecords("host.autoco.internal.", RecordTypeA) + if !exists { + t.Error("Expected wildcard match to exist") + } if len(ips) != 1 { t.Errorf("Expected 1 IP for wildcard match, got %d", len(ips)) } - if !ips[0].Equal(wildcardIP) { + if len(ips) > 0 && !ips[0].Equal(wildcardIP) { t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0]) } // Test non-match (base domain) - ips = store.GetRecords("autoco.internal.", RecordTypeA) + ips, exists = store.GetRecords("autoco.internal.", RecordTypeA) + if exists { + t.Error("Expected base domain to not exist") + } if len(ips) != 0 { t.Errorf("Expected 0 IPs for base domain, got %d", len(ips)) } @@ -212,13 +221,16 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) { // Add complex wildcard pattern ip1 := net.ParseIP("10.0.0.1") - err := store.AddRecord("*.host-0?.autoco.internal", ip1) + err := store.AddRecord("*.host-0?.autoco.internal", ip1, 0) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } // Test matching domain - ips := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA) + ips, exists := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA) + if !exists { + t.Error("Expected complex wildcard match to exist") + } if len(ips) != 1 { t.Errorf("Expected 1 IP for complex wildcard match, got %d", len(ips)) } @@ -227,13 +239,19 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) { } // Test non-matching domain (missing prefix) - ips = store.GetRecords("host-01.autoco.internal.", RecordTypeA) + ips, exists = store.GetRecords("host-01.autoco.internal.", RecordTypeA) + if exists { + t.Error("Expected domain without prefix to not exist") + } if len(ips) != 0 { t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips)) } // Test non-matching domain (wrong ? 
position) - ips = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA) + ips, exists = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA) + if exists { + t.Error("Expected domain with wrong ? match to not exist") + } if len(ips) != 0 { t.Errorf("Expected 0 IPs for domain with wrong ? match, got %d", len(ips)) } @@ -244,13 +262,16 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) { // Add wildcard record ip := net.ParseIP("10.0.0.1") - err := store.AddRecord("*.autoco.internal", ip) + err := store.AddRecord("*.autoco.internal", ip, 0) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } // Verify it exists - ips := store.GetRecords("host.autoco.internal.", RecordTypeA) + ips, exists := store.GetRecords("host.autoco.internal.", RecordTypeA) + if !exists { + t.Error("Expected domain to exist before removal") + } if len(ips) != 1 { t.Errorf("Expected 1 IP before removal, got %d", len(ips)) } @@ -259,7 +280,10 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) { store.RemoveRecord("*.autoco.internal", nil) // Verify it's gone - ips = store.GetRecords("host.autoco.internal.", RecordTypeA) + ips, exists = store.GetRecords("host.autoco.internal.", RecordTypeA) + if exists { + t.Error("Expected domain to not exist after removal") + } if len(ips) != 0 { t.Errorf("Expected 0 IPs after removal, got %d", len(ips)) } @@ -273,36 +297,36 @@ func TestDNSRecordStoreMultipleWildcards(t *testing.T) { ip2 := net.ParseIP("10.0.0.2") ip3 := net.ParseIP("10.0.0.3") - err := store.AddRecord("*.prod.autoco.internal", ip1) + err := store.AddRecord("*.prod.autoco.internal", ip1, 0) if err != nil { t.Fatalf("Failed to add first wildcard: %v", err) } - err = store.AddRecord("*.dev.autoco.internal", ip2) + err = store.AddRecord("*.dev.autoco.internal", ip2, 0) if err != nil { t.Fatalf("Failed to add second wildcard: %v", err) } // Add a broader wildcard that matches both - err = store.AddRecord("*.autoco.internal", ip3) + err = 
store.AddRecord("*.autoco.internal", ip3, 0) if err != nil { t.Fatalf("Failed to add third wildcard: %v", err) } // Test domain matching only the prod pattern and the broad pattern - ips := store.GetRecords("host.prod.autoco.internal.", RecordTypeA) + ips, _ := store.GetRecords("host.prod.autoco.internal.", RecordTypeA) if len(ips) != 2 { t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips)) } // Test domain matching only the dev pattern and the broad pattern - ips = store.GetRecords("service.dev.autoco.internal.", RecordTypeA) + ips, _ = store.GetRecords("service.dev.autoco.internal.", RecordTypeA) if len(ips) != 2 { t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips)) } // Test domain matching only the broad pattern - ips = store.GetRecords("host.test.autoco.internal.", RecordTypeA) + ips, _ = store.GetRecords("host.test.autoco.internal.", RecordTypeA) if len(ips) != 1 { t.Errorf("Expected 1 IP (broad only), got %d", len(ips)) } @@ -313,13 +337,13 @@ func TestDNSRecordStoreIPv6Wildcard(t *testing.T) { // Add IPv6 wildcard record ip := net.ParseIP("2001:db8::1") - err := store.AddRecord("*.autoco.internal", ip) + err := store.AddRecord("*.autoco.internal", ip, 0) if err != nil { t.Fatalf("Failed to add IPv6 wildcard record: %v", err) } // Test wildcard match for IPv6 - ips := store.GetRecords("host.autoco.internal.", RecordTypeAAAA) + ips, _ := store.GetRecords("host.autoco.internal.", RecordTypeAAAA) if len(ips) != 1 { t.Errorf("Expected 1 IPv6 for wildcard match, got %d", len(ips)) } @@ -333,7 +357,7 @@ func TestHasRecordWildcard(t *testing.T) { // Add wildcard record ip := net.ParseIP("10.0.0.1") - err := store.AddRecord("*.autoco.internal", ip) + err := store.AddRecord("*.autoco.internal", ip, 0) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } @@ -354,7 +378,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) { // Add record with mixed case ip := net.ParseIP("10.0.0.1") - err := store.AddRecord("MyHost.AutoCo.Internal", 
ip) + err := store.AddRecord("MyHost.AutoCo.Internal", ip, 0) if err != nil { t.Fatalf("Failed to add mixed case record: %v", err) } @@ -368,7 +392,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) { } for _, domain := range testCases { - ips := store.GetRecords(domain, RecordTypeA) + ips, _ := store.GetRecords(domain, RecordTypeA) if len(ips) != 1 { t.Errorf("Expected 1 IP for domain %q, got %d", domain, len(ips)) } @@ -379,7 +403,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) { // Test wildcard with mixed case wildcardIP := net.ParseIP("10.0.0.2") - err = store.AddRecord("*.Example.Com", wildcardIP) + err = store.AddRecord("*.Example.Com", wildcardIP, 0) if err != nil { t.Fatalf("Failed to add mixed case wildcard: %v", err) } @@ -392,7 +416,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) { } for _, domain := range wildcardTestCases { - ips := store.GetRecords(domain, RecordTypeA) + ips, _ := store.GetRecords(domain, RecordTypeA) if len(ips) != 1 { t.Errorf("Expected 1 IP for wildcard domain %q, got %d", domain, len(ips)) } @@ -403,7 +427,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) { // Test removal with different case store.RemoveRecord("MYHOST.AUTOCO.INTERNAL", nil) - ips := store.GetRecords("myhost.autoco.internal.", RecordTypeA) + ips, _ := store.GetRecords("myhost.autoco.internal.", RecordTypeA) if len(ips) != 0 { t.Errorf("Expected 0 IPs after removal, got %d", len(ips)) } @@ -665,7 +689,7 @@ func TestClearPTRRecords(t *testing.T) { store.AddPTRRecord(ip2, "host2.example.com.") // Add some A records too - store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1")) + store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1"), 0) // Verify PTR records exist if !store.HasPTRRecord("1.1.168.192.in-addr.arpa.") { @@ -695,7 +719,7 @@ func TestAutomaticPTRRecordOnAdd(t *testing.T) { // Add an A record - should automatically add PTR record domain := "host.example.com." 
ip := net.ParseIP("192.168.1.100") - err := store.AddRecord(domain, ip) + err := store.AddRecord(domain, ip, 0) if err != nil { t.Fatalf("Failed to add A record: %v", err) } @@ -713,7 +737,7 @@ func TestAutomaticPTRRecordOnAdd(t *testing.T) { // Add AAAA record - should also automatically add PTR record domain6 := "ipv6host.example.com." ip6 := net.ParseIP("2001:db8::1") - err = store.AddRecord(domain6, ip6) + err = store.AddRecord(domain6, ip6, 0) if err != nil { t.Fatalf("Failed to add AAAA record: %v", err) } @@ -735,7 +759,7 @@ func TestAutomaticPTRRecordOnRemove(t *testing.T) { // Add an A record (with automatic PTR) domain := "host.example.com." ip := net.ParseIP("192.168.1.100") - store.AddRecord(domain, ip) + store.AddRecord(domain, ip, 0) // Verify PTR exists reverseDomain := "100.1.168.192.in-addr.arpa." @@ -752,7 +776,7 @@ func TestAutomaticPTRRecordOnRemove(t *testing.T) { } // Verify A record is also gone - ips := store.GetRecords(domain, RecordTypeA) + ips, _ := store.GetRecords(domain, RecordTypeA) if len(ips) != 0 { t.Errorf("Expected A record to be removed, got %d records", len(ips)) } @@ -765,8 +789,8 @@ func TestAutomaticPTRRecordOnRemoveAll(t *testing.T) { domain := "host.example.com." ip1 := net.ParseIP("192.168.1.100") ip2 := net.ParseIP("192.168.1.101") - store.AddRecord(domain, ip1) - store.AddRecord(domain, ip2) + store.AddRecord(domain, ip1, 0) + store.AddRecord(domain, ip2, 0) // Verify both PTR records exist reverseDomain1 := "100.1.168.192.in-addr.arpa." @@ -796,7 +820,7 @@ func TestNoPTRForWildcardRecords(t *testing.T) { // Add wildcard record - should NOT create PTR record domain := "*.example.com." ip := net.ParseIP("192.168.1.100") - err := store.AddRecord(domain, ip) + err := store.AddRecord(domain, ip, 0) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } @@ -820,7 +844,7 @@ func TestPTRRecordOverwrite(t *testing.T) { // Add first domain with IP domain1 := "host1.example.com." 
ip := net.ParseIP("192.168.1.100") - store.AddRecord(domain1, ip) + store.AddRecord(domain1, ip, 0) // Verify PTR points to first domain reverseDomain := "100.1.168.192.in-addr.arpa." @@ -834,7 +858,7 @@ func TestPTRRecordOverwrite(t *testing.T) { // Add second domain with same IP - should overwrite PTR domain2 := "host2.example.com." - store.AddRecord(domain2, ip) + store.AddRecord(domain2, ip, 0) // Verify PTR now points to second domain (last one added) result, ok = store.GetPTRRecord(reverseDomain) diff --git a/main.go b/main.go index 2bf8dcd..6ea7d11 100644 --- a/main.go +++ b/main.go @@ -190,7 +190,7 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt os.Exit(0) } - olmVersion := "version_replaceme" + olmVersion := "1.4.3" if showVersion { fmt.Println("Olm version " + olmVersion) os.Exit(0) diff --git a/olm.iss b/olm.iss index 1893f8e..4216d88 100644 --- a/olm.iss +++ b/olm.iss @@ -32,7 +32,7 @@ DefaultGroupName={#MyAppName} DisableProgramGroupPage=yes ; Uncomment the following line to run in non administrative install mode (install for current user only). ;PrivilegesRequired=lowest -OutputBaseFilename=mysetup +OutputBaseFilename=olm_windows_installer SolidCompression=yes WizardStyle=modern ; Add this to ensure PATH changes are applied and the system is prompted for a restart if needed @@ -78,7 +78,7 @@ begin Result := True; exit; end; - + // Perform a case-insensitive check to see if the path is already present. // We add semicolons to prevent partial matches (e.g., matching C:\App in C:\App2). 
if Pos(';' + UpperCase(Path) + ';', ';' + UpperCase(OrigPath) + ';') > 0 then @@ -109,7 +109,7 @@ begin PathList.Delimiter := ';'; PathList.StrictDelimiter := True; PathList.DelimitedText := OrigPath; - + // Find and remove the matching entry (case-insensitive) for I := PathList.Count - 1 downto 0 do begin @@ -119,10 +119,10 @@ begin PathList.Delete(I); end; end; - + // Reconstruct the PATH NewPath := PathList.DelimitedText; - + // Write the new PATH back to the registry if RegWriteExpandStringValue(HKEY_LOCAL_MACHINE, 'SYSTEM\CurrentControlSet\Control\Session Manager\Environment', @@ -145,8 +145,8 @@ begin // Get the application installation path AppPath := ExpandConstant('{app}'); Log('Removing PATH entry for: ' + AppPath); - + // Remove only our path entry from the system PATH RemovePathEntry(AppPath); end; -end; +end; \ No newline at end of file diff --git a/olm/connect.go b/olm/connect.go index d2c477f..3a2000c 100644 --- a/olm/connect.go +++ b/olm/connect.go @@ -7,6 +7,7 @@ import ( "runtime" "strconv" "strings" + "time" "github.com/fosrl/newt/logger" "github.com/fosrl/newt/network" @@ -173,16 +174,20 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) { for i := range wgData.Sites { site := wgData.Sites[i] - var siteEndpoint string - // here we are going to take the relay endpoint if it exists which means we requested a relay for this peer - if site.RelayEndpoint != "" { - siteEndpoint = site.RelayEndpoint - } else { - siteEndpoint = site.Endpoint + + if site.PublicKey != "" { + var siteEndpoint string + // here we are going to take the relay endpoint if it exists which means we requested a relay for this peer + if site.RelayEndpoint != "" { + siteEndpoint = site.RelayEndpoint + } else { + siteEndpoint = site.Endpoint + } + + o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false) } - o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false) - + // we still call this to add the aliases for jit lookup but we 
just do that then pass inside. need to skip the above so we dont add to the api if err := o.peerManager.AddPeer(site); err != nil { logger.Error("Failed to add peer: %v", err) return @@ -197,6 +202,36 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) { logger.Error("Failed to start DNS proxy: %v", err) } + // Register JIT handler: when the DNS proxy resolves a local record, check whether + // the owning site is already connected and, if not, initiate a JIT connection. + o.dnsProxy.SetJITHandler(func(siteId int) { + if o.peerManager == nil || o.websocket == nil { + return + } + + // Site already has an active peer connection - nothing to do. + if _, exists := o.peerManager.GetPeer(siteId); exists { + return + } + + o.peerSendMu.Lock() + defer o.peerSendMu.Unlock() + + // A JIT request for this site is already in-flight - avoid duplicate sends. + if _, pending := o.jitPendingSites[siteId]; pending { + return + } + + chainId := generateChainId() + logger.Info("DNS-triggered JIT connect for site %d (chainId=%s)", siteId, chainId) + stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{ + "siteId": siteId, + "chainId": chainId, + }, 2*time.Second, 10) + o.stopPeerInits[chainId] = stopFunc + o.jitPendingSites[siteId] = chainId + }) + if o.tunnelConfig.OverrideDNS { // Set up DNS override to use our DNS proxy if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil { @@ -274,12 +309,12 @@ func (o *Olm) handleTerminate(msg websocket.WSMessage) { logger.Error("Error unmarshaling terminate error data: %v", err) } else { logger.Info("Terminate reason (code: %s): %s", errorData.Code, errorData.Message) - + if errorData.Code == "TERMINATED_INACTIVITY" { logger.Info("Ignoring...") return } - + // Set the olm error in the API server so it can be exposed via status o.apiServer.SetOlmError(errorData.Code, errorData.Message) } diff --git a/olm/data.go b/olm/data.go index 
8bd0997..d0e6d5b 100644 --- a/olm/data.go +++ b/olm/data.go @@ -2,6 +2,7 @@ package olm import ( "encoding/json" + "fmt" "time" "github.com/fosrl/newt/holepunch" @@ -220,6 +221,7 @@ func (o *Olm) handleSync(msg websocket.WSMessage) { logger.Info("Sync: Adding new peer for site %d", siteId) o.holePunchManager.TriggerHolePunch() + o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud // // TODO: do we need to send the message to the cloud to add the peer that way? // if err := o.peerManager.AddPeer(expectedSite); err != nil { @@ -230,9 +232,17 @@ func (o *Olm) handleSync(msg websocket.WSMessage) { // add the peer via the server // this is important because newt needs to get triggered as well to add the peer once the hp is complete - o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ - "siteId": expectedSite.SiteId, - }, 1*time.Second, 10) + chainId := fmt.Sprintf("sync-%d", expectedSite.SiteId) + o.peerSendMu.Lock() + if stop, ok := o.stopPeerSends[chainId]; ok { + stop() + } + stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ + "siteId": expectedSite.SiteId, + "chainId": chainId, + }, 2*time.Second, 10) + o.stopPeerSends[chainId] = stopFunc + o.peerSendMu.Unlock() } else { // Existing peer - check if update is needed diff --git a/olm/olm.go b/olm/olm.go index 56998a6..a458f8a 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -2,6 +2,8 @@ package olm import ( "context" + "crypto/rand" + "encoding/hex" "fmt" "net" "net/http" @@ -65,7 +67,10 @@ type Olm struct { stopRegister func() updateRegister func(newData any) - stopPeerSend func() + stopPeerSends map[string]func() + stopPeerInits map[string]func() + jitPendingSites map[int]string // siteId -> chainId for in-flight JIT requests + peerSendMu sync.Mutex // WaitGroup to track tunnel lifecycle tunnelWg sync.WaitGroup @@ -116,6 +121,13 @@ func (o *Olm) 
initTunnelInfo(clientID string) error { return nil } +// generateChainId generates a random chain ID for tracking peer sender lifecycles. +func generateChainId() string { + b := make([]byte, 8) + _, _ = rand.Read(b) + return hex.EncodeToString(b) +} + func Init(ctx context.Context, config OlmConfig) (*Olm, error) { logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel)) @@ -166,10 +178,13 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) { apiServer.SetAgent(config.Agent) newOlm := &Olm{ - logFile: logFile, - olmCtx: ctx, - apiServer: apiServer, - olmConfig: config, + logFile: logFile, + olmCtx: ctx, + apiServer: apiServer, + olmConfig: config, + stopPeerSends: make(map[string]func()), + stopPeerInits: make(map[string]func()), + jitPendingSites: make(map[int]string), } newOlm.registerAPICallbacks() @@ -284,6 +299,21 @@ func (o *Olm) registerAPICallbacks() { logger.Info("Processing power mode change request via API: mode=%s", req.Mode) return o.SetPowerMode(req.Mode) }, + func(req api.JITConnectionRequest) error { + logger.Info("Processing JIT connect request via API: site=%s resource=%s", req.Site, req.Resource) + + chainId := generateChainId() + o.peerSendMu.Lock() + stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{ + "siteId": req.Site, + "resourceId": req.Resource, + "chainId": chainId, + }, 2*time.Second, 10) + o.stopPeerInits[chainId] = stopFunc + o.peerSendMu.Unlock() + + return nil + }, ) } @@ -345,7 +375,6 @@ func (o *Olm) StartTunnel(config TunnelConfig) { config.OrgID, config.Endpoint, 30*time.Second, // 30 seconds - config.PingTimeoutDuration, websocket.WithPingDataProvider(func() map[string]any { o.metaMu.Lock() defer o.metaMu.Unlock() @@ -385,6 +414,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) { // Handler for peer handshake - adds exit node to holepunch rotation and notifies server o.websocket.RegisterHandler("olm/wg/peer/holepunch/site/add", 
o.handleWgPeerHolepunchAddSite) + o.websocket.RegisterHandler("olm/wg/peer/chain/cancel", o.handleCancelChain) o.websocket.RegisterHandler("olm/sync", o.handleSync) o.websocket.OnConnect(func() error { @@ -427,7 +457,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) { "userToken": userToken, "fingerprint": o.fingerprint, "postures": o.postures, - }, 1*time.Second, 10) + }, 2*time.Second, 10) // Invoke onRegistered callback if configured if o.olmConfig.OnRegistered != nil { @@ -524,6 +554,23 @@ func (o *Olm) Close() { o.stopRegister = nil } + // Stop all pending peer init and send senders before closing websocket + o.peerSendMu.Lock() + for _, stop := range o.stopPeerInits { + if stop != nil { + stop() + } + } + o.stopPeerInits = make(map[string]func()) + for _, stop := range o.stopPeerSends { + if stop != nil { + stop() + } + } + o.stopPeerSends = make(map[string]func()) + o.jitPendingSites = make(map[int]string) + o.peerSendMu.Unlock() + // send a disconnect message to the cloud to show disconnected if o.websocket != nil { o.websocket.SendMessage("olm/disconnecting", map[string]any{}) diff --git a/olm/peer.go b/olm/peer.go index 0e2d2da..fca47b5 100644 --- a/olm/peer.go +++ b/olm/peer.go @@ -20,31 +20,51 @@ func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) { return } - if o.stopPeerSend != nil { - o.stopPeerSend() - o.stopPeerSend = nil - } - jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Error("Error marshaling data: %v", err) return } - var siteConfig peers.SiteConfig - if err := json.Unmarshal(jsonData, &siteConfig); err != nil { + var siteConfigMsg struct { + peers.SiteConfig + ChainId string `json:"chainId"` + } + if err := json.Unmarshal(jsonData, &siteConfigMsg); err != nil { logger.Error("Error unmarshaling add data: %v", err) return } + if siteConfigMsg.ChainId != "" { + o.peerSendMu.Lock() + if stop, ok := o.stopPeerSends[siteConfigMsg.ChainId]; ok { + stop() + delete(o.stopPeerSends, siteConfigMsg.ChainId) + } + 
o.peerSendMu.Unlock() + } else { + // stop all of the stopPeerSends + o.peerSendMu.Lock() + for _, stop := range o.stopPeerSends { + stop() + } + o.stopPeerSends = make(map[string]func()) + o.peerSendMu.Unlock() + } + + if siteConfigMsg.PublicKey == "" { + logger.Warn("Skipping add-peer for site %d (%s): no public key available (site may not be connected)", siteConfigMsg.SiteId, siteConfigMsg.Name) + return + } + _ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it - if err := o.peerManager.AddPeer(siteConfig); err != nil { + if err := o.peerManager.AddPeer(siteConfigMsg.SiteConfig); err != nil { logger.Error("Failed to add peer: %v", err) return } - logger.Info("Successfully added peer for site %d", siteConfig.SiteId) + logger.Info("Successfully added peer for site %d", siteConfigMsg.SiteId) } func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) { @@ -164,13 +184,21 @@ func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) { return } - var relayData peers.RelayPeerData + var relayData struct { + peers.RelayPeerData + ChainId string `json:"chainId"` + } if err := json.Unmarshal(jsonData, &relayData); err != nil { logger.Error("Error unmarshaling relay data: %v", err) return } + if monitor := o.peerManager.GetPeerMonitor(); monitor != nil { + monitor.CancelRelaySend(relayData.ChainId) + } + primaryRelay, err := util.ResolveDomainUpstream(relayData.RelayEndpoint, o.tunnelConfig.PublicDNS) + if err != nil { logger.Error("Failed to resolve primary relay endpoint: %v", err) return @@ -197,13 +225,21 @@ func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) { return } - var relayData peers.UnRelayPeerData + var relayData struct { + peers.UnRelayPeerData + ChainId string `json:"chainId"` + } if err := json.Unmarshal(jsonData, &relayData); err != nil { logger.Error("Error unmarshaling relay data: %v", err) return } + if monitor := 
o.peerManager.GetPeerMonitor(); monitor != nil { + monitor.CancelRelaySend(relayData.ChainId) + } + primaryRelay, err := util.ResolveDomainUpstream(relayData.Endpoint, o.tunnelConfig.PublicDNS) + if err != nil { logger.Warn("Failed to resolve primary relay endpoint: %v", err) } @@ -230,7 +266,8 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { } var handshakeData struct { - SiteId int `json:"siteId"` + SiteId int `json:"siteId"` + ChainId string `json:"chainId"` ExitNode struct { PublicKey string `json:"publicKey"` Endpoint string `json:"endpoint"` @@ -243,6 +280,27 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { return } + // Stop the peer init sender for this chain, if any + if handshakeData.ChainId != "" { + o.peerSendMu.Lock() + if stop, ok := o.stopPeerInits[handshakeData.ChainId]; ok { + stop() + delete(o.stopPeerInits, handshakeData.ChainId) + } + // If this chain was initiated by a DNS-triggered JIT request, clear the + // pending entry so the site can be re-triggered if needed in the future. 
+ delete(o.jitPendingSites, handshakeData.SiteId) + o.peerSendMu.Unlock() + } else { + // Stop all of the stopPeerInits + o.peerSendMu.Lock() + for _, stop := range o.stopPeerInits { + stop() + } + o.stopPeerInits = make(map[string]func()) + o.peerSendMu.Unlock() + } + // Get existing peer from PeerManager _, exists := o.peerManager.GetPeer(handshakeData.SiteId) if exists { @@ -273,10 +331,72 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud - // Send handshake acknowledgment back to server with retry - o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ - "siteId": handshakeData.SiteId, - }, 1*time.Second, 10) + // Send handshake acknowledgment back to server with retry, keyed by chainId + chainId := handshakeData.ChainId + if chainId == "" { + chainId = generateChainId() + } + o.peerSendMu.Lock() + stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ + "siteId": handshakeData.SiteId, + "chainId": chainId, + }, 2*time.Second, 10) + o.stopPeerSends[chainId] = stopFunc + o.peerSendMu.Unlock() logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) } + +func (o *Olm) handleCancelChain(msg websocket.WSMessage) { + logger.Debug("Received cancel-chain message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling cancel-chain data: %v", err) + return + } + + var cancelData struct { + ChainId string `json:"chainId"` + } + if err := json.Unmarshal(jsonData, &cancelData); err != nil { + logger.Error("Error unmarshaling cancel-chain data: %v", err) + return + } + + if cancelData.ChainId == "" { + logger.Warn("Received cancel-chain message 
with no chainId") + return + } + + o.peerSendMu.Lock() + defer o.peerSendMu.Unlock() + + found := false + + if stop, ok := o.stopPeerInits[cancelData.ChainId]; ok { + stop() + delete(o.stopPeerInits, cancelData.ChainId) + found = true + } + // If this chain was a DNS-triggered JIT request, clear the pending entry so + // the site can be re-triggered on the next DNS lookup. + for siteId, chainId := range o.jitPendingSites { + if chainId == cancelData.ChainId { + delete(o.jitPendingSites, siteId) + break + } + } + + if stop, ok := o.stopPeerSends[cancelData.ChainId]; ok { + stop() + delete(o.stopPeerSends, cancelData.ChainId) + found = true + } + + if found { + logger.Info("Cancelled chain %s", cancelData.ChainId) + } else { + logger.Warn("Cancel-chain: no active sender found for chain %s", cancelData.ChainId) + } +} diff --git a/peers/manager.go b/peers/manager.go index 514c0af..9cc1e75 100644 --- a/peers/manager.go +++ b/peers/manager.go @@ -110,6 +110,19 @@ func (pm *PeerManager) GetAllPeers() []SiteConfig { func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error { pm.mu.Lock() defer pm.mu.Unlock() + + for _, alias := range siteConfig.Aliases { + address := net.ParseIP(alias.AliasAddress) + if address == nil { + continue + } + pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteConfig.SiteId) + } + + if siteConfig.PublicKey == "" { + logger.Debug("Skip adding site %d because no pub key", siteConfig.SiteId) + return nil + } // build the allowed IPs list from the remote subnets and aliases and add them to the peer allowedIPs := make([]string, 0, len(siteConfig.RemoteSubnets)+len(siteConfig.Aliases)) @@ -143,14 +156,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error { if err := network.AddRoutes(siteConfig.RemoteSubnets, pm.interfaceName); err != nil { logger.Error("Failed to add routes for remote subnets: %v", err) } - for _, alias := range siteConfig.Aliases { - address := net.ParseIP(alias.AliasAddress) - if address == nil { - continue - } - 
pm.dnsProxy.AddDNSRecord(alias.Alias, address) - } - + monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0] monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port @@ -437,7 +443,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error { if address == nil { continue } - pm.dnsProxy.AddDNSRecord(alias.Alias, address) + pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteConfig.SiteId) } pm.peerMonitor.UpdateHolepunchEndpoint(siteConfig.SiteId, siteConfig.Endpoint) @@ -717,7 +723,7 @@ func (pm *PeerManager) AddAlias(siteId int, alias Alias) error { address := net.ParseIP(alias.AliasAddress) if address != nil { - pm.dnsProxy.AddDNSRecord(alias.Alias, address) + pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteId) } // Add an allowed IP for the alias diff --git a/peers/monitor/monitor.go b/peers/monitor/monitor.go index cfea418..56dcee4 100644 --- a/peers/monitor/monitor.go +++ b/peers/monitor/monitor.go @@ -2,6 +2,8 @@ package monitor import ( "context" + "crypto/rand" + "encoding/hex" "fmt" "net" "net/netip" @@ -31,11 +33,15 @@ type PeerMonitor struct { monitors map[int]*Client mutex sync.Mutex running bool - timeout time.Duration + timeout time.Duration maxAttempts int wsClient *websocket.Client publicDNS []string + // Relay sender tracking + relaySends map[string]func() + relaySendMu sync.Mutex + // Netstack fields middleDev *middleDevice.MiddleDevice localIP string @@ -48,13 +54,13 @@ type PeerMonitor struct { nsWg sync.WaitGroup // Holepunch testing fields - sharedBind *bind.SharedBind - holepunchTester *holepunch.HolepunchTester - holepunchTimeout time.Duration - holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing - holepunchStatus map[int]bool // siteID -> connected status - holepunchStopChan chan struct{} - holepunchUpdateChan chan struct{} + sharedBind *bind.SharedBind + holepunchTester *holepunch.HolepunchTester + holepunchTimeout time.Duration + 
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing + holepunchStatus map[int]bool // siteID -> connected status + holepunchStopChan chan struct{} + holepunchUpdateChan chan struct{} // Relay tracking fields relayedPeers map[int]bool // siteID -> whether the peer is currently relayed @@ -83,6 +89,12 @@ type PeerMonitor struct { } // NewPeerMonitor creates a new peer monitor with the given callback +func generateChainId() string { + b := make([]byte, 8) + _, _ = rand.Read(b) + return hex.EncodeToString(b) +} + func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API, publicDNS []string) *PeerMonitor { ctx, cancel := context.WithCancel(context.Background()) pm := &PeerMonitor{ @@ -101,6 +113,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe holepunchEndpoints: make(map[int]string), holepunchStatus: make(map[int]bool), relayedPeers: make(map[int]bool), + relaySends: make(map[string]func()), holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures holepunchFailures: make(map[int]int), // Rapid initial test settings: complete within ~1.5 seconds @@ -398,20 +411,23 @@ func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status Connectio } } -// sendRelay sends a relay message to the server +// sendRelay sends a relay message to the server with retry, keyed by chainId func (pm *PeerMonitor) sendRelay(siteID int) error { if pm.wsClient == nil { return fmt.Errorf("websocket client is nil") } - err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{ - "siteId": siteID, - }) - if err != nil { - logger.Error("Failed to send registration message: %v", err) - return err - } - logger.Info("Sent relay message") + chainId := generateChainId() + stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/relay", map[string]interface{}{ + "siteId": siteID, + "chainId": chainId, + }, 2*time.Second, 
10) + + pm.relaySendMu.Lock() + pm.relaySends[chainId] = stopFunc + pm.relaySendMu.Unlock() + + logger.Info("Sent relay message for site %d (chain %s)", siteID, chainId) return nil } @@ -421,23 +437,52 @@ func (pm *PeerMonitor) RequestRelay(siteID int) error { return pm.sendRelay(siteID) } -// sendUnRelay sends an unrelay message to the server +// sendUnRelay sends an unrelay message to the server with retry, keyed by chainId func (pm *PeerMonitor) sendUnRelay(siteID int) error { if pm.wsClient == nil { return fmt.Errorf("websocket client is nil") } - err := pm.wsClient.SendMessage("olm/wg/unrelay", map[string]interface{}{ - "siteId": siteID, - }) - if err != nil { - logger.Error("Failed to send registration message: %v", err) - return err - } - logger.Info("Sent unrelay message") + chainId := generateChainId() + stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/unrelay", map[string]interface{}{ + "siteId": siteID, + "chainId": chainId, + }, 2*time.Second, 10) + + pm.relaySendMu.Lock() + pm.relaySends[chainId] = stopFunc + pm.relaySendMu.Unlock() + + logger.Info("Sent unrelay message for site %d (chain %s)", siteID, chainId) return nil } +// CancelRelaySend stops the interval sender for the given chainId, if one exists. +// If chainId is empty, all active relay senders are stopped. 
+func (pm *PeerMonitor) CancelRelaySend(chainId string) { + pm.relaySendMu.Lock() + defer pm.relaySendMu.Unlock() + + if chainId == "" { + for id, stop := range pm.relaySends { + if stop != nil { + stop() + } + delete(pm.relaySends, id) + } + logger.Info("Cancelled all relay senders") + return + } + + if stop, ok := pm.relaySends[chainId]; ok { + stop() + delete(pm.relaySends, chainId) + logger.Info("Cancelled relay sender for chain %s", chainId) + } else { + logger.Warn("CancelRelaySend: no active sender for chain %s", chainId) + } +} + // Stop stops monitoring all peers func (pm *PeerMonitor) Stop() { // Stop holepunch monitor first (outside of mutex to avoid deadlock) @@ -536,7 +581,7 @@ func (pm *PeerMonitor) runHolepunchMonitor() { pm.holepunchCurrentInterval = pm.holepunchMinInterval currentInterval := pm.holepunchCurrentInterval pm.mutex.Unlock() - + timer.Reset(currentInterval) logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval) case <-timer.C: @@ -679,6 +724,16 @@ func (pm *PeerMonitor) Close() { // Stop holepunch monitor first (outside of mutex to avoid deadlock) pm.stopHolepunchMonitor() + // Stop all pending relay senders + pm.relaySendMu.Lock() + for chainId, stop := range pm.relaySends { + if stop != nil { + stop() + } + delete(pm.relaySends, chainId) + } + pm.relaySendMu.Unlock() + pm.mutex.Lock() defer pm.mutex.Unlock() diff --git a/websocket/client.go b/websocket/client.go index dcf6acd..3b4e894 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -2,6 +2,7 @@ package websocket import ( "bytes" + "compress/gzip" "crypto/tls" "crypto/x509" "encoding/json" @@ -82,7 +83,6 @@ type Client struct { isDisconnected bool // Flag to track if client is intentionally disconnected reconnectMux sync.RWMutex pingInterval time.Duration - pingTimeout time.Duration onConnect func() error onTokenUpdate func(token string, exitNodes []ExitNode) onAuthError func(statusCode int, message string) // Callback for auth errors @@ -158,7 
+158,7 @@ func (c *Client) OnAuthError(callback func(statusCode int, message string)) { } // NewClient creates a new websocket client -func NewClient(ID, secret, userToken, orgId, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) { +func NewClient(ID, secret, userToken, orgId, endpoint string, pingInterval time.Duration, opts ...ClientOption) (*Client, error) { config := &Config{ ID: ID, Secret: secret, @@ -175,7 +175,6 @@ func NewClient(ID, secret, userToken, orgId, endpoint string, pingInterval time. reconnectInterval: 3 * time.Second, isConnected: false, pingInterval: pingInterval, - pingTimeout: pingTimeout, clientType: "olm", pingDone: make(chan struct{}), } @@ -803,8 +802,7 @@ func (c *Client) readPumpWithDisconnectDetection() { case <-c.done: return default: - var msg WSMessage - err := c.conn.ReadJSON(&msg) + messageType, p, err := c.conn.ReadMessage() if err != nil { // Check if we're shutting down or explicitly disconnected before logging error select { @@ -829,6 +827,30 @@ func (c *Client) readPumpWithDisconnectDetection() { } } + // Decompress binary frames (gzip-compressed JSON) + var data []byte + if messageType == websocket.BinaryMessage { + gr, gzErr := gzip.NewReader(bytes.NewReader(p)) + if gzErr != nil { + logger.Error("websocket: failed to create gzip reader: %v", gzErr) + continue + } + data, gzErr = io.ReadAll(gr) + gr.Close() + if gzErr != nil { + logger.Error("websocket: failed to decompress message: %v", gzErr) + continue + } + } else { + data = p + } + + var msg WSMessage + if err = json.Unmarshal(data, &msg); err != nil { + logger.Error("websocket: failed to parse message: %v", err) + continue + } + // Update config version from incoming message c.setConfigVersion(msg.ConfigVersion)