Compare commits

..

1 Commits

Author SHA1 Message Date
dependabot[bot]
c7b590a23e Bump golang from 1.25-alpine to 1.26-alpine in the minor-updates group
Bumps the minor-updates group with 1 update: golang.


Updates `golang` from 1.25-alpine to 1.26-alpine

---
updated-dependencies:
- dependency-name: golang
  dependency-version: 1.26-alpine
  dependency-type: direct:production
  dependency-group: minor-updates
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-02-25 20:54:57 +00:00
11 changed files with 251 additions and 673 deletions

View File

@@ -1,4 +1,4 @@
FROM golang:1.25-alpine AS builder FROM golang:1.26-alpine AS builder
# Set the working directory inside the container # Set the working directory inside the container
WORKDIR /app WORKDIR /app

View File

@@ -78,13 +78,6 @@ type MetadataChangeRequest struct {
Postures map[string]any `json:"postures"` Postures map[string]any `json:"postures"`
} }
// JITConnectionRequest defines the structure for a dynamic Just-In-Time connection request.
// Either SiteID or ResourceID must be provided (but not necessarily both).
type JITConnectionRequest struct {
Site string `json:"site,omitempty"`
Resource string `json:"resource,omitempty"`
}
// API represents the HTTP server and its state // API represents the HTTP server and its state
type API struct { type API struct {
addr string addr string
@@ -99,7 +92,6 @@ type API struct {
onExit func() error onExit func() error
onRebind func() error onRebind func() error
onPowerMode func(PowerModeRequest) error onPowerMode func(PowerModeRequest) error
onJITConnect func(JITConnectionRequest) error
statusMu sync.RWMutex statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus peerStatuses map[int]*PeerStatus
@@ -151,7 +143,6 @@ func (s *API) SetHandlers(
onExit func() error, onExit func() error,
onRebind func() error, onRebind func() error,
onPowerMode func(PowerModeRequest) error, onPowerMode func(PowerModeRequest) error,
onJITConnect func(JITConnectionRequest) error,
) { ) {
s.onConnect = onConnect s.onConnect = onConnect
s.onSwitchOrg = onSwitchOrg s.onSwitchOrg = onSwitchOrg
@@ -160,7 +151,6 @@ func (s *API) SetHandlers(
s.onExit = onExit s.onExit = onExit
s.onRebind = onRebind s.onRebind = onRebind
s.onPowerMode = onPowerMode s.onPowerMode = onPowerMode
s.onJITConnect = onJITConnect
} }
// Start starts the HTTP server // Start starts the HTTP server
@@ -179,7 +169,6 @@ func (s *API) Start() error {
mux.HandleFunc("/health", s.handleHealth) mux.HandleFunc("/health", s.handleHealth)
mux.HandleFunc("/rebind", s.handleRebind) mux.HandleFunc("/rebind", s.handleRebind)
mux.HandleFunc("/power-mode", s.handlePowerMode) mux.HandleFunc("/power-mode", s.handlePowerMode)
mux.HandleFunc("/jit-connect", s.handleJITConnect)
s.server = &http.Server{ s.server = &http.Server{
Handler: mux, Handler: mux,
@@ -644,54 +633,6 @@ func (s *API) handleRebind(w http.ResponseWriter, r *http.Request) {
}) })
} }
// handleJITConnect handles the /jit-connect endpoint.
// It initiates a dynamic Just-In-Time connection to a site identified by either
// a site or a resource. Exactly one of the two must be provided.
func (s *API) handleJITConnect(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req JITConnectionRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
return
}
// Validate that exactly one of site or resource is provided
if req.Site == "" && req.Resource == "" {
http.Error(w, "Missing required field: either site or resource must be provided", http.StatusBadRequest)
return
}
if req.Site != "" && req.Resource != "" {
http.Error(w, "Ambiguous request: provide either site or resource, not both", http.StatusBadRequest)
return
}
if req.Site != "" {
logger.Info("Received JIT connection request via API: site=%s", req.Site)
} else {
logger.Info("Received JIT connection request via API: resource=%s", req.Resource)
}
if s.onJITConnect != nil {
if err := s.onJITConnect(req); err != nil {
http.Error(w, fmt.Sprintf("JIT connection failed: %v", err), http.StatusInternalServerError)
return
}
} else {
http.Error(w, "JIT connect handler not configured", http.StatusNotImplemented)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusAccepted)
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "JIT connection request accepted",
})
}
// handlePowerMode handles the /power-mode endpoint // handlePowerMode handles the /power-mode endpoint
// This allows changing the power mode between "normal" and "low" // This allows changing the power mode between "normal" and "low"
func (s *API) handlePowerMode(w http.ResponseWriter, r *http.Request) { func (s *API) handlePowerMode(w http.ResponseWriter, r *http.Request) {

View File

@@ -447,20 +447,19 @@ func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns
return nil return nil
} }
ips, exists := p.recordStore.GetRecords(question.Name, recordType) ips := p.recordStore.GetRecords(question.Name, recordType)
if !exists { if len(ips) == 0 {
// Domain not found in local records, forward to upstream
return nil return nil
} }
logger.Debug("Found %d local record(s) for %s", len(ips), question.Name) logger.Debug("Found %d local record(s) for %s", len(ips), question.Name)
// Create response message (NODATA if no records, otherwise with answers) // Create response message
response := new(dns.Msg) response := new(dns.Msg)
response.SetReply(query) response.SetReply(query)
response.Authoritative = true response.Authoritative = true
// Add answer records (loop is a no-op if ips is empty) // Add answer records
for _, ip := range ips { for _, ip := range ips {
var rr dns.RR var rr dns.RR
if question.Qtype == dns.TypeA { if question.Qtype == dns.TypeA {
@@ -731,9 +730,8 @@ func (p *DNSProxy) RemoveDNSRecord(domain string, ip net.IP) {
p.recordStore.RemoveRecord(domain, ip) p.recordStore.RemoveRecord(domain, ip)
} }
// GetDNSRecords returns all IP addresses for a domain and record type. // GetDNSRecords returns all IP addresses for a domain and record type
// The second return value indicates whether the domain exists. func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) []net.IP {
func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) ([]net.IP, bool) {
return p.recordStore.GetRecords(domain, recordType) return p.recordStore.GetRecords(domain, recordType)
} }

View File

@@ -1,178 +0,0 @@
package dns
import (
"net"
"testing"
"github.com/miekg/dns"
)
func TestCheckLocalRecordsNODATAForAAAA(t *testing.T) {
proxy := &DNSProxy{
recordStore: NewDNSRecordStore(),
}
// Add an A record for a domain (no AAAA record)
ip := net.ParseIP("10.0.0.1")
err := proxy.recordStore.AddRecord("myservice.internal", ip)
if err != nil {
t.Fatalf("Failed to add A record: %v", err)
}
// Query AAAA for domain with only A record - should return NODATA
query := new(dns.Msg)
query.SetQuestion("myservice.internal.", dns.TypeAAAA)
response := proxy.checkLocalRecords(query, query.Question[0])
if response == nil {
t.Fatal("Expected NODATA response, got nil (would forward to upstream)")
}
if response.Rcode != dns.RcodeSuccess {
t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode)
}
if len(response.Answer) != 0 {
t.Errorf("Expected empty answer section for NODATA, got %d answers", len(response.Answer))
}
if !response.Authoritative {
t.Error("Expected response to be authoritative")
}
// Query A for same domain - should return the record
query = new(dns.Msg)
query.SetQuestion("myservice.internal.", dns.TypeA)
response = proxy.checkLocalRecords(query, query.Question[0])
if response == nil {
t.Fatal("Expected response with A record, got nil")
}
if len(response.Answer) != 1 {
t.Fatalf("Expected 1 answer, got %d", len(response.Answer))
}
aRecord, ok := response.Answer[0].(*dns.A)
if !ok {
t.Fatal("Expected A record in answer")
}
if !aRecord.A.Equal(ip.To4()) {
t.Errorf("Expected IP %v, got %v", ip.To4(), aRecord.A)
}
}
func TestCheckLocalRecordsNODATAForA(t *testing.T) {
proxy := &DNSProxy{
recordStore: NewDNSRecordStore(),
}
// Add an AAAA record for a domain (no A record)
ip := net.ParseIP("2001:db8::1")
err := proxy.recordStore.AddRecord("ipv6only.internal", ip)
if err != nil {
t.Fatalf("Failed to add AAAA record: %v", err)
}
// Query A for domain with only AAAA record - should return NODATA
query := new(dns.Msg)
query.SetQuestion("ipv6only.internal.", dns.TypeA)
response := proxy.checkLocalRecords(query, query.Question[0])
if response == nil {
t.Fatal("Expected NODATA response, got nil")
}
if response.Rcode != dns.RcodeSuccess {
t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode)
}
if len(response.Answer) != 0 {
t.Errorf("Expected empty answer section, got %d answers", len(response.Answer))
}
if !response.Authoritative {
t.Error("Expected response to be authoritative")
}
// Query AAAA for same domain - should return the record
query = new(dns.Msg)
query.SetQuestion("ipv6only.internal.", dns.TypeAAAA)
response = proxy.checkLocalRecords(query, query.Question[0])
if response == nil {
t.Fatal("Expected response with AAAA record, got nil")
}
if len(response.Answer) != 1 {
t.Fatalf("Expected 1 answer, got %d", len(response.Answer))
}
aaaaRecord, ok := response.Answer[0].(*dns.AAAA)
if !ok {
t.Fatal("Expected AAAA record in answer")
}
if !aaaaRecord.AAAA.Equal(ip) {
t.Errorf("Expected IP %v, got %v", ip, aaaaRecord.AAAA)
}
}
func TestCheckLocalRecordsNonExistentDomain(t *testing.T) {
proxy := &DNSProxy{
recordStore: NewDNSRecordStore(),
}
// Add a record so the store isn't empty
err := proxy.recordStore.AddRecord("exists.internal", net.ParseIP("10.0.0.1"))
if err != nil {
t.Fatalf("Failed to add record: %v", err)
}
// Query A for non-existent domain - should return nil (forward to upstream)
query := new(dns.Msg)
query.SetQuestion("unknown.internal.", dns.TypeA)
response := proxy.checkLocalRecords(query, query.Question[0])
if response != nil {
t.Error("Expected nil for non-existent domain, got response")
}
// Query AAAA for non-existent domain - should also return nil
query = new(dns.Msg)
query.SetQuestion("unknown.internal.", dns.TypeAAAA)
response = proxy.checkLocalRecords(query, query.Question[0])
if response != nil {
t.Error("Expected nil for non-existent domain, got response")
}
}
func TestCheckLocalRecordsNODATAWildcard(t *testing.T) {
proxy := &DNSProxy{
recordStore: NewDNSRecordStore(),
}
// Add a wildcard A record (no AAAA)
ip := net.ParseIP("10.0.0.1")
err := proxy.recordStore.AddRecord("*.wildcard.internal", ip)
if err != nil {
t.Fatalf("Failed to add wildcard A record: %v", err)
}
// Query AAAA for wildcard-matched domain - should return NODATA
query := new(dns.Msg)
query.SetQuestion("host.wildcard.internal.", dns.TypeAAAA)
response := proxy.checkLocalRecords(query, query.Question[0])
if response == nil {
t.Fatal("Expected NODATA response for wildcard match, got nil")
}
if response.Rcode != dns.RcodeSuccess {
t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode)
}
if len(response.Answer) != 0 {
t.Errorf("Expected empty answer section, got %d answers", len(response.Answer))
}
// Query A for wildcard-matched domain - should return the record
query = new(dns.Msg)
query.SetQuestion("host.wildcard.internal.", dns.TypeA)
response = proxy.checkLocalRecords(query, query.Question[0])
if response == nil {
t.Fatal("Expected response with A record, got nil")
}
if len(response.Answer) != 1 {
t.Fatalf("Expected 1 answer, got %d", len(response.Answer))
}
}

View File

@@ -18,26 +18,23 @@ const (
RecordTypePTR RecordType = RecordType(dns.TypePTR) RecordTypePTR RecordType = RecordType(dns.TypePTR)
) )
// recordSet holds A and AAAA records for a single domain or wildcard pattern // DNSRecordStore manages local DNS records for A, AAAA, and PTR queries
type recordSet struct {
A []net.IP
AAAA []net.IP
}
// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries.
// Exact domains are stored in a map; wildcard patterns are in a separate map.
type DNSRecordStore struct { type DNSRecordStore struct {
mu sync.RWMutex mu sync.RWMutex
exact map[string]*recordSet // normalized FQDN -> A/AAAA records aRecords map[string][]net.IP // domain -> list of IPv4 addresses
wildcards map[string]*recordSet // wildcard pattern -> A/AAAA records aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses
aWildcards map[string][]net.IP // wildcard pattern -> list of IPv4 addresses
aaaaWildcards map[string][]net.IP // wildcard pattern -> list of IPv6 addresses
ptrRecords map[string]string // IP address string -> domain name ptrRecords map[string]string // IP address string -> domain name
} }
// NewDNSRecordStore creates a new DNS record store // NewDNSRecordStore creates a new DNS record store
func NewDNSRecordStore() *DNSRecordStore { func NewDNSRecordStore() *DNSRecordStore {
return &DNSRecordStore{ return &DNSRecordStore{
exact: make(map[string]*recordSet), aRecords: make(map[string][]net.IP),
wildcards: make(map[string]*recordSet), aaaaRecords: make(map[string][]net.IP),
aWildcards: make(map[string][]net.IP),
aaaaWildcards: make(map[string][]net.IP),
ptrRecords: make(map[string]string), ptrRecords: make(map[string]string),
} }
} }
@@ -51,37 +48,39 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
// Ensure domain ends with a dot (FQDN format)
if len(domain) == 0 || domain[len(domain)-1] != '.' { if len(domain) == 0 || domain[len(domain)-1] != '.' {
domain = domain + "." domain = domain + "."
} }
// Normalize domain to lowercase FQDN
domain = strings.ToLower(dns.Fqdn(domain)) domain = strings.ToLower(dns.Fqdn(domain))
// Check if domain contains wildcards
isWildcard := strings.ContainsAny(domain, "*?") isWildcard := strings.ContainsAny(domain, "*?")
isV4 := ip.To4() != nil if ip.To4() != nil {
if !isV4 && ip.To16() == nil { // IPv4 address
if isWildcard {
s.aWildcards[domain] = append(s.aWildcards[domain], ip)
} else {
s.aRecords[domain] = append(s.aRecords[domain], ip)
// Automatically add PTR record for non-wildcard domains
s.ptrRecords[ip.String()] = domain
}
} else if ip.To16() != nil {
// IPv6 address
if isWildcard {
s.aaaaWildcards[domain] = append(s.aaaaWildcards[domain], ip)
} else {
s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip)
// Automatically add PTR record for non-wildcard domains
s.ptrRecords[ip.String()] = domain
}
} else {
return &net.ParseError{Type: "IP address", Text: ip.String()} return &net.ParseError{Type: "IP address", Text: ip.String()}
} }
// Choose the appropriate map based on whether this is a wildcard
m := s.exact
if isWildcard {
m = s.wildcards
}
if m[domain] == nil {
m[domain] = &recordSet{}
}
rs := m[domain]
if isV4 {
rs.A = append(rs.A, ip)
} else {
rs.AAAA = append(rs.AAAA, ip)
}
// Add PTR record for non-wildcard domains
if !isWildcard {
s.ptrRecords[ip.String()] = domain
}
return nil return nil
} }
@@ -113,61 +112,88 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
// Ensure domain ends with a dot (FQDN format)
if len(domain) == 0 || domain[len(domain)-1] != '.' { if len(domain) == 0 || domain[len(domain)-1] != '.' {
domain = domain + "." domain = domain + "."
} }
// Normalize domain to lowercase FQDN
domain = strings.ToLower(dns.Fqdn(domain)) domain = strings.ToLower(dns.Fqdn(domain))
// Check if domain contains wildcards
isWildcard := strings.ContainsAny(domain, "*?") isWildcard := strings.ContainsAny(domain, "*?")
// Choose the appropriate map
m := s.exact
if isWildcard {
m = s.wildcards
}
rs := m[domain]
if rs == nil {
return
}
if ip == nil { if ip == nil {
// Remove all records for this domain // Remove all records for this domain
if !isWildcard { if isWildcard {
for _, ipAddr := range rs.A { delete(s.aWildcards, domain)
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { delete(s.aaaaWildcards, domain)
delete(s.ptrRecords, ipAddr.String()) } else {
} // For non-wildcard domains, remove PTR records for all IPs
} if ips, ok := s.aRecords[domain]; ok {
for _, ipAddr := range rs.AAAA { for _, ipAddr := range ips {
// Only remove PTR if it points to this domain
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ipAddr.String()) delete(s.ptrRecords, ipAddr.String())
} }
} }
} }
delete(m, domain) if ips, ok := s.aaaaRecords[domain]; ok {
for _, ipAddr := range ips {
// Only remove PTR if it points to this domain
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ipAddr.String())
}
}
}
delete(s.aRecords, domain)
delete(s.aaaaRecords, domain)
}
return return
} }
// Remove specific IP
if ip.To4() != nil { if ip.To4() != nil {
rs.A = removeIP(rs.A, ip) // Remove specific IPv4 address
if !isWildcard { if isWildcard {
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { if ips, ok := s.aWildcards[domain]; ok {
delete(s.ptrRecords, ip.String()) s.aWildcards[domain] = removeIP(ips, ip)
if len(s.aWildcards[domain]) == 0 {
delete(s.aWildcards, domain)
} }
} }
} else { } else {
rs.AAAA = removeIP(rs.AAAA, ip) if ips, ok := s.aRecords[domain]; ok {
if !isWildcard { s.aRecords[domain] = removeIP(ips, ip)
if len(s.aRecords[domain]) == 0 {
delete(s.aRecords, domain)
}
// Automatically remove PTR record if it points to this domain
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ip.String())
}
}
}
} else if ip.To16() != nil {
// Remove specific IPv6 address
if isWildcard {
if ips, ok := s.aaaaWildcards[domain]; ok {
s.aaaaWildcards[domain] = removeIP(ips, ip)
if len(s.aaaaWildcards[domain]) == 0 {
delete(s.aaaaWildcards, domain)
}
}
} else {
if ips, ok := s.aaaaRecords[domain]; ok {
s.aaaaRecords[domain] = removeIP(ips, ip)
if len(s.aaaaRecords[domain]) == 0 {
delete(s.aaaaRecords, domain)
}
// Automatically remove PTR record if it points to this domain
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ip.String()) delete(s.ptrRecords, ip.String())
} }
} }
} }
// Clean up empty record sets
if len(rs.A) == 0 && len(rs.AAAA) == 0 {
delete(m, domain)
} }
} }
@@ -179,56 +205,61 @@ func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) {
delete(s.ptrRecords, ip.String()) delete(s.ptrRecords, ip.String())
} }
// GetRecords returns all IP addresses for a domain and record type. // GetRecords returns all IP addresses for a domain and record type
// The second return value indicates whether the domain exists at all // First checks for exact matches, then checks wildcard patterns
// (true = domain exists, use NODATA if no records; false = NXDOMAIN). func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP {
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) ([]net.IP, bool) {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
// Normalize domain to lowercase FQDN
domain = strings.ToLower(dns.Fqdn(domain)) domain = strings.ToLower(dns.Fqdn(domain))
// Check exact match first
if rs, exists := s.exact[domain]; exists {
var ips []net.IP
if recordType == RecordTypeA {
ips = rs.A
} else {
ips = rs.AAAA
}
if len(ips) > 0 {
out := make([]net.IP, len(ips))
copy(out, ips)
return out, true
}
// Domain exists but no records of this type
return nil, true
}
// Check wildcard matches
var records []net.IP var records []net.IP
matched := false switch recordType {
for pattern, rs := range s.wildcards { case RecordTypeA:
if !matchWildcard(pattern, domain) { // Check exact match first
continue if ips, ok := s.aRecords[domain]; ok {
// Return a copy to prevent external modifications
records = make([]net.IP, len(ips))
copy(records, ips)
return records
} }
matched = true // Check wildcard patterns
if recordType == RecordTypeA { for pattern, ips := range s.aWildcards {
records = append(records, rs.A...) if matchWildcard(pattern, domain) {
} else { records = append(records, ips...)
records = append(records, rs.AAAA...) }
}
if len(records) > 0 {
// Return a copy
result := make([]net.IP, len(records))
copy(result, records)
return result
}
case RecordTypeAAAA:
// Check exact match first
if ips, ok := s.aaaaRecords[domain]; ok {
// Return a copy to prevent external modifications
records = make([]net.IP, len(ips))
copy(records, ips)
return records
}
// Check wildcard patterns
for pattern, ips := range s.aaaaWildcards {
if matchWildcard(pattern, domain) {
records = append(records, ips...)
}
}
if len(records) > 0 {
// Return a copy
result := make([]net.IP, len(records))
copy(result, records)
return result
} }
} }
if !matched { return records
return nil, false
}
if len(records) == 0 {
return nil, true
}
out := make([]net.IP, len(records))
copy(out, records)
return out, true
} }
// GetPTRRecord returns the domain name for a PTR record query // GetPTRRecord returns the domain name for a PTR record query
@@ -257,30 +288,34 @@ func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
// Normalize domain to lowercase FQDN
domain = strings.ToLower(dns.Fqdn(domain)) domain = strings.ToLower(dns.Fqdn(domain))
switch recordType {
case RecordTypeA:
// Check exact match // Check exact match
if rs, exists := s.exact[domain]; exists { if _, ok := s.aRecords[domain]; ok {
if recordType == RecordTypeA && len(rs.A) > 0 {
return true return true
} }
if recordType == RecordTypeAAAA && len(rs.AAAA) > 0 { // Check wildcard patterns
for pattern := range s.aWildcards {
if matchWildcard(pattern, domain) {
return true return true
} }
} }
case RecordTypeAAAA:
// Check exact match
if _, ok := s.aaaaRecords[domain]; ok {
return true
}
// Check wildcard patterns
for pattern := range s.aaaaWildcards {
if matchWildcard(pattern, domain) {
return true
}
}
}
// Check wildcard matches
for pattern, rs := range s.wildcards {
if !matchWildcard(pattern, domain) {
continue
}
if recordType == RecordTypeA && len(rs.A) > 0 {
return true
}
if recordType == RecordTypeAAAA && len(rs.AAAA) > 0 {
return true
}
}
return false return false
} }
@@ -304,8 +339,10 @@ func (s *DNSRecordStore) Clear() {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
s.exact = make(map[string]*recordSet) s.aRecords = make(map[string][]net.IP)
s.wildcards = make(map[string]*recordSet) s.aaaaRecords = make(map[string][]net.IP)
s.aWildcards = make(map[string][]net.IP)
s.aaaaWildcards = make(map[string][]net.IP)
s.ptrRecords = make(map[string]string) s.ptrRecords = make(map[string]string)
} }

View File

@@ -183,34 +183,25 @@ func TestDNSRecordStoreWildcard(t *testing.T) {
} }
// Test exact match takes precedence // Test exact match takes precedence
ips, exists := store.GetRecords("exact.autoco.internal.", RecordTypeA) ips := store.GetRecords("exact.autoco.internal.", RecordTypeA)
if !exists {
t.Error("Expected domain to exist")
}
if len(ips) != 1 { if len(ips) != 1 {
t.Errorf("Expected 1 IP for exact match, got %d", len(ips)) t.Errorf("Expected 1 IP for exact match, got %d", len(ips))
} }
if len(ips) > 0 && !ips[0].Equal(exactIP) { if !ips[0].Equal(exactIP) {
t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0]) t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0])
} }
// Test wildcard match // Test wildcard match
ips, exists = store.GetRecords("host.autoco.internal.", RecordTypeA) ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
if !exists {
t.Error("Expected wildcard match to exist")
}
if len(ips) != 1 { if len(ips) != 1 {
t.Errorf("Expected 1 IP for wildcard match, got %d", len(ips)) t.Errorf("Expected 1 IP for wildcard match, got %d", len(ips))
} }
if len(ips) > 0 && !ips[0].Equal(wildcardIP) { if !ips[0].Equal(wildcardIP) {
t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0]) t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0])
} }
// Test non-match (base domain) // Test non-match (base domain)
ips, exists = store.GetRecords("autoco.internal.", RecordTypeA) ips = store.GetRecords("autoco.internal.", RecordTypeA)
if exists {
t.Error("Expected base domain to not exist")
}
if len(ips) != 0 { if len(ips) != 0 {
t.Errorf("Expected 0 IPs for base domain, got %d", len(ips)) t.Errorf("Expected 0 IPs for base domain, got %d", len(ips))
} }
@@ -227,10 +218,7 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) {
} }
// Test matching domain // Test matching domain
ips, exists := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA) ips := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA)
if !exists {
t.Error("Expected complex wildcard match to exist")
}
if len(ips) != 1 { if len(ips) != 1 {
t.Errorf("Expected 1 IP for complex wildcard match, got %d", len(ips)) t.Errorf("Expected 1 IP for complex wildcard match, got %d", len(ips))
} }
@@ -239,19 +227,13 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) {
} }
// Test non-matching domain (missing prefix) // Test non-matching domain (missing prefix)
ips, exists = store.GetRecords("host-01.autoco.internal.", RecordTypeA) ips = store.GetRecords("host-01.autoco.internal.", RecordTypeA)
if exists {
t.Error("Expected domain without prefix to not exist")
}
if len(ips) != 0 { if len(ips) != 0 {
t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips)) t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips))
} }
// Test non-matching domain (wrong ? position) // Test non-matching domain (wrong ? position)
ips, exists = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA) ips = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA)
if exists {
t.Error("Expected domain with wrong ? match to not exist")
}
if len(ips) != 0 { if len(ips) != 0 {
t.Errorf("Expected 0 IPs for domain with wrong ? match, got %d", len(ips)) t.Errorf("Expected 0 IPs for domain with wrong ? match, got %d", len(ips))
} }
@@ -268,10 +250,7 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
} }
// Verify it exists // Verify it exists
ips, exists := store.GetRecords("host.autoco.internal.", RecordTypeA) ips := store.GetRecords("host.autoco.internal.", RecordTypeA)
if !exists {
t.Error("Expected domain to exist before removal")
}
if len(ips) != 1 { if len(ips) != 1 {
t.Errorf("Expected 1 IP before removal, got %d", len(ips)) t.Errorf("Expected 1 IP before removal, got %d", len(ips))
} }
@@ -280,10 +259,7 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
store.RemoveRecord("*.autoco.internal", nil) store.RemoveRecord("*.autoco.internal", nil)
// Verify it's gone // Verify it's gone
ips, exists = store.GetRecords("host.autoco.internal.", RecordTypeA) ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
if exists {
t.Error("Expected domain to not exist after removal")
}
if len(ips) != 0 { if len(ips) != 0 {
t.Errorf("Expected 0 IPs after removal, got %d", len(ips)) t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
} }
@@ -314,19 +290,19 @@ func TestDNSRecordStoreMultipleWildcards(t *testing.T) {
} }
// Test domain matching only the prod pattern and the broad pattern // Test domain matching only the prod pattern and the broad pattern
ips, _ := store.GetRecords("host.prod.autoco.internal.", RecordTypeA) ips := store.GetRecords("host.prod.autoco.internal.", RecordTypeA)
if len(ips) != 2 { if len(ips) != 2 {
t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips)) t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips))
} }
// Test domain matching only the dev pattern and the broad pattern // Test domain matching only the dev pattern and the broad pattern
ips, _ = store.GetRecords("service.dev.autoco.internal.", RecordTypeA) ips = store.GetRecords("service.dev.autoco.internal.", RecordTypeA)
if len(ips) != 2 { if len(ips) != 2 {
t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips)) t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips))
} }
// Test domain matching only the broad pattern // Test domain matching only the broad pattern
ips, _ = store.GetRecords("host.test.autoco.internal.", RecordTypeA) ips = store.GetRecords("host.test.autoco.internal.", RecordTypeA)
if len(ips) != 1 { if len(ips) != 1 {
t.Errorf("Expected 1 IP (broad only), got %d", len(ips)) t.Errorf("Expected 1 IP (broad only), got %d", len(ips))
} }
@@ -343,7 +319,7 @@ func TestDNSRecordStoreIPv6Wildcard(t *testing.T) {
} }
// Test wildcard match for IPv6 // Test wildcard match for IPv6
ips, _ := store.GetRecords("host.autoco.internal.", RecordTypeAAAA) ips := store.GetRecords("host.autoco.internal.", RecordTypeAAAA)
if len(ips) != 1 { if len(ips) != 1 {
t.Errorf("Expected 1 IPv6 for wildcard match, got %d", len(ips)) t.Errorf("Expected 1 IPv6 for wildcard match, got %d", len(ips))
} }
@@ -392,7 +368,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
} }
for _, domain := range testCases { for _, domain := range testCases {
ips, _ := store.GetRecords(domain, RecordTypeA) ips := store.GetRecords(domain, RecordTypeA)
if len(ips) != 1 { if len(ips) != 1 {
t.Errorf("Expected 1 IP for domain %q, got %d", domain, len(ips)) t.Errorf("Expected 1 IP for domain %q, got %d", domain, len(ips))
} }
@@ -416,7 +392,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
} }
for _, domain := range wildcardTestCases { for _, domain := range wildcardTestCases {
ips, _ := store.GetRecords(domain, RecordTypeA) ips := store.GetRecords(domain, RecordTypeA)
if len(ips) != 1 { if len(ips) != 1 {
t.Errorf("Expected 1 IP for wildcard domain %q, got %d", domain, len(ips)) t.Errorf("Expected 1 IP for wildcard domain %q, got %d", domain, len(ips))
} }
@@ -427,7 +403,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
// Test removal with different case // Test removal with different case
store.RemoveRecord("MYHOST.AUTOCO.INTERNAL", nil) store.RemoveRecord("MYHOST.AUTOCO.INTERNAL", nil)
ips, _ := store.GetRecords("myhost.autoco.internal.", RecordTypeA) ips := store.GetRecords("myhost.autoco.internal.", RecordTypeA)
if len(ips) != 0 { if len(ips) != 0 {
t.Errorf("Expected 0 IPs after removal, got %d", len(ips)) t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
} }
@@ -776,7 +752,7 @@ func TestAutomaticPTRRecordOnRemove(t *testing.T) {
} }
// Verify A record is also gone // Verify A record is also gone
ips, _ := store.GetRecords(domain, RecordTypeA) ips := store.GetRecords(domain, RecordTypeA)
if len(ips) != 0 { if len(ips) != 0 {
t.Errorf("Expected A record to be removed, got %d records", len(ips)) t.Errorf("Expected A record to be removed, got %d records", len(ips))
} }

View File

@@ -32,7 +32,7 @@ DefaultGroupName={#MyAppName}
DisableProgramGroupPage=yes DisableProgramGroupPage=yes
; Uncomment the following line to run in non administrative install mode (install for current user only). ; Uncomment the following line to run in non administrative install mode (install for current user only).
;PrivilegesRequired=lowest ;PrivilegesRequired=lowest
OutputBaseFilename=olm_windows_installer OutputBaseFilename=mysetup
SolidCompression=yes SolidCompression=yes
WizardStyle=modern WizardStyle=modern
; Add this to ensure PATH changes are applied and the system is prompted for a restart if needed ; Add this to ensure PATH changes are applied and the system is prompted for a restart if needed

View File

@@ -2,7 +2,6 @@ package olm
import ( import (
"encoding/json" "encoding/json"
"fmt"
"time" "time"
"github.com/fosrl/newt/holepunch" "github.com/fosrl/newt/holepunch"
@@ -221,7 +220,6 @@ func (o *Olm) handleSync(msg websocket.WSMessage) {
logger.Info("Sync: Adding new peer for site %d", siteId) logger.Info("Sync: Adding new peer for site %d", siteId)
o.holePunchManager.TriggerHolePunch() o.holePunchManager.TriggerHolePunch()
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
// // TODO: do we need to send the message to the cloud to add the peer that way? // // TODO: do we need to send the message to the cloud to add the peer that way?
// if err := o.peerManager.AddPeer(expectedSite); err != nil { // if err := o.peerManager.AddPeer(expectedSite); err != nil {
@@ -232,17 +230,9 @@ func (o *Olm) handleSync(msg websocket.WSMessage) {
// add the peer via the server // add the peer via the server
// this is important because newt needs to get triggered as well to add the peer once the hp is complete // this is important because newt needs to get triggered as well to add the peer once the hp is complete
chainId := fmt.Sprintf("sync-%d", expectedSite.SiteId) o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
o.peerSendMu.Lock()
if stop, ok := o.stopPeerSends[chainId]; ok {
stop()
}
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": expectedSite.SiteId, "siteId": expectedSite.SiteId,
"chainId": chainId, }, 1*time.Second, 10)
}, 2*time.Second, 10)
o.stopPeerSends[chainId] = stopFunc
o.peerSendMu.Unlock()
} else { } else {
// Existing peer - check if update is needed // Existing peer - check if update is needed

View File

@@ -2,8 +2,6 @@ package olm
import ( import (
"context" "context"
"crypto/rand"
"encoding/hex"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@@ -67,9 +65,7 @@ type Olm struct {
stopRegister func() stopRegister func()
updateRegister func(newData any) updateRegister func(newData any)
stopPeerSends map[string]func() stopPeerSend func()
stopPeerInits map[string]func()
peerSendMu sync.Mutex
// WaitGroup to track tunnel lifecycle // WaitGroup to track tunnel lifecycle
tunnelWg sync.WaitGroup tunnelWg sync.WaitGroup
@@ -120,13 +116,6 @@ func (o *Olm) initTunnelInfo(clientID string) error {
return nil return nil
} }
// generateChainId generates a random chain ID for tracking peer sender lifecycles.
func generateChainId() string {
b := make([]byte, 8)
_, _ = rand.Read(b)
return hex.EncodeToString(b)
}
func Init(ctx context.Context, config OlmConfig) (*Olm, error) { func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel)) logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
@@ -181,8 +170,6 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
olmCtx: ctx, olmCtx: ctx,
apiServer: apiServer, apiServer: apiServer,
olmConfig: config, olmConfig: config,
stopPeerSends: make(map[string]func()),
stopPeerInits: make(map[string]func()),
} }
newOlm.registerAPICallbacks() newOlm.registerAPICallbacks()
@@ -297,21 +284,6 @@ func (o *Olm) registerAPICallbacks() {
logger.Info("Processing power mode change request via API: mode=%s", req.Mode) logger.Info("Processing power mode change request via API: mode=%s", req.Mode)
return o.SetPowerMode(req.Mode) return o.SetPowerMode(req.Mode)
}, },
func(req api.JITConnectionRequest) error {
logger.Info("Processing JIT connect request via API: site=%s resource=%s", req.Site, req.Resource)
chainId := generateChainId()
o.peerSendMu.Lock()
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{
"siteId": req.Site,
"resourceId": req.Resource,
"chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerInits[chainId] = stopFunc
o.peerSendMu.Unlock()
return nil
},
) )
} }
@@ -406,7 +378,6 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
// Handler for peer handshake - adds exit node to holepunch rotation and notifies server // Handler for peer handshake - adds exit node to holepunch rotation and notifies server
o.websocket.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite) o.websocket.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite)
o.websocket.RegisterHandler("olm/wg/peer/chain/cancel", o.handleCancelChain)
o.websocket.RegisterHandler("olm/sync", o.handleSync) o.websocket.RegisterHandler("olm/sync", o.handleSync)
o.websocket.OnConnect(func() error { o.websocket.OnConnect(func() error {
@@ -449,7 +420,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
"userToken": userToken, "userToken": userToken,
"fingerprint": o.fingerprint, "fingerprint": o.fingerprint,
"postures": o.postures, "postures": o.postures,
}, 2*time.Second, 10) }, 1*time.Second, 10)
// Invoke onRegistered callback if configured // Invoke onRegistered callback if configured
if o.olmConfig.OnRegistered != nil { if o.olmConfig.OnRegistered != nil {
@@ -546,22 +517,6 @@ func (o *Olm) Close() {
o.stopRegister = nil o.stopRegister = nil
} }
// Stop all pending peer init and send senders before closing websocket
o.peerSendMu.Lock()
for _, stop := range o.stopPeerInits {
if stop != nil {
stop()
}
}
o.stopPeerInits = make(map[string]func())
for _, stop := range o.stopPeerSends {
if stop != nil {
stop()
}
}
o.stopPeerSends = make(map[string]func())
o.peerSendMu.Unlock()
// send a disconnect message to the cloud to show disconnected // send a disconnect message to the cloud to show disconnected
if o.websocket != nil { if o.websocket != nil {
o.websocket.SendMessage("olm/disconnecting", map[string]any{}) o.websocket.SendMessage("olm/disconnecting", map[string]any{})

View File

@@ -20,38 +20,31 @@ func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
return return
} }
if o.stopPeerSend != nil {
o.stopPeerSend()
o.stopPeerSend = nil
}
jsonData, err := json.Marshal(msg.Data) jsonData, err := json.Marshal(msg.Data)
if err != nil { if err != nil {
logger.Error("Error marshaling data: %v", err) logger.Error("Error marshaling data: %v", err)
return return
} }
var siteConfigMsg struct { var siteConfig peers.SiteConfig
peers.SiteConfig if err := json.Unmarshal(jsonData, &siteConfig); err != nil {
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &siteConfigMsg); err != nil {
logger.Error("Error unmarshaling add data: %v", err) logger.Error("Error unmarshaling add data: %v", err)
return return
} }
if siteConfigMsg.ChainId != "" {
o.peerSendMu.Lock()
if stop, ok := o.stopPeerSends[siteConfigMsg.ChainId]; ok {
stop()
delete(o.stopPeerSends, siteConfigMsg.ChainId)
}
o.peerSendMu.Unlock()
}
_ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it _ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
if err := o.peerManager.AddPeer(siteConfigMsg.SiteConfig); err != nil { if err := o.peerManager.AddPeer(siteConfig); err != nil {
logger.Error("Failed to add peer: %v", err) logger.Error("Failed to add peer: %v", err)
return return
} }
logger.Info("Successfully added peer for site %d", siteConfigMsg.SiteId) logger.Info("Successfully added peer for site %d", siteConfig.SiteId)
} }
func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) { func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
@@ -171,19 +164,12 @@ func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) {
return return
} }
var relayData struct { var relayData peers.RelayPeerData
peers.RelayPeerData
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &relayData); err != nil { if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err) logger.Error("Error unmarshaling relay data: %v", err)
return return
} }
if monitor := o.peerManager.GetPeerMonitor(); monitor != nil {
monitor.CancelRelaySend(relayData.ChainId)
}
primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint) primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint)
if err != nil { if err != nil {
logger.Error("Failed to resolve primary relay endpoint: %v", err) logger.Error("Failed to resolve primary relay endpoint: %v", err)
@@ -211,19 +197,12 @@ func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
return return
} }
var relayData struct { var relayData peers.UnRelayPeerData
peers.UnRelayPeerData
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &relayData); err != nil { if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err) logger.Error("Error unmarshaling relay data: %v", err)
return return
} }
if monitor := o.peerManager.GetPeerMonitor(); monitor != nil {
monitor.CancelRelaySend(relayData.ChainId)
}
primaryRelay, err := util.ResolveDomain(relayData.Endpoint) primaryRelay, err := util.ResolveDomain(relayData.Endpoint)
if err != nil { if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err) logger.Warn("Failed to resolve primary relay endpoint: %v", err)
@@ -252,7 +231,6 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
var handshakeData struct { var handshakeData struct {
SiteId int `json:"siteId"` SiteId int `json:"siteId"`
ChainId string `json:"chainId"`
ExitNode struct { ExitNode struct {
PublicKey string `json:"publicKey"` PublicKey string `json:"publicKey"`
Endpoint string `json:"endpoint"` Endpoint string `json:"endpoint"`
@@ -265,16 +243,6 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
return return
} }
// Stop the peer init sender for this chain, if any
if handshakeData.ChainId != "" {
o.peerSendMu.Lock()
if stop, ok := o.stopPeerInits[handshakeData.ChainId]; ok {
stop()
delete(o.stopPeerInits, handshakeData.ChainId)
}
o.peerSendMu.Unlock()
}
// Get existing peer from PeerManager // Get existing peer from PeerManager
_, exists := o.peerManager.GetPeer(handshakeData.SiteId) _, exists := o.peerManager.GetPeer(handshakeData.SiteId)
if exists { if exists {
@@ -305,64 +273,10 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
// Send handshake acknowledgment back to server with retry, keyed by chainId // Send handshake acknowledgment back to server with retry
chainId := handshakeData.ChainId o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
if chainId == "" {
chainId = generateChainId()
}
o.peerSendMu.Lock()
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": handshakeData.SiteId, "siteId": handshakeData.SiteId,
"chainId": chainId, }, 1*time.Second, 10)
}, 2*time.Second, 10)
o.stopPeerSends[chainId] = stopFunc
o.peerSendMu.Unlock()
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
} }
func (o *Olm) handleCancelChain(msg websocket.WSMessage) {
logger.Debug("Received cancel-chain message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling cancel-chain data: %v", err)
return
}
var cancelData struct {
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &cancelData); err != nil {
logger.Error("Error unmarshaling cancel-chain data: %v", err)
return
}
if cancelData.ChainId == "" {
logger.Warn("Received cancel-chain message with no chainId")
return
}
o.peerSendMu.Lock()
defer o.peerSendMu.Unlock()
found := false
if stop, ok := o.stopPeerInits[cancelData.ChainId]; ok {
stop()
delete(o.stopPeerInits, cancelData.ChainId)
found = true
}
if stop, ok := o.stopPeerSends[cancelData.ChainId]; ok {
stop()
delete(o.stopPeerSends, cancelData.ChainId)
found = true
}
if found {
logger.Info("Cancelled chain %s", cancelData.ChainId)
} else {
logger.Warn("Cancel-chain: no active sender found for chain %s", cancelData.ChainId)
}
}

View File

@@ -2,8 +2,6 @@ package monitor
import ( import (
"context" "context"
"crypto/rand"
"encoding/hex"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@@ -37,10 +35,6 @@ type PeerMonitor struct {
maxAttempts int maxAttempts int
wsClient *websocket.Client wsClient *websocket.Client
// Relay sender tracking
relaySends map[string]func()
relaySendMu sync.Mutex
// Netstack fields // Netstack fields
middleDev *middleDevice.MiddleDevice middleDev *middleDevice.MiddleDevice
localIP string localIP string
@@ -88,12 +82,6 @@ type PeerMonitor struct {
} }
// NewPeerMonitor creates a new peer monitor with the given callback // NewPeerMonitor creates a new peer monitor with the given callback
func generateChainId() string {
b := make([]byte, 8)
_, _ = rand.Read(b)
return hex.EncodeToString(b)
}
func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API) *PeerMonitor { func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API) *PeerMonitor {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
pm := &PeerMonitor{ pm := &PeerMonitor{
@@ -111,7 +99,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
holepunchEndpoints: make(map[int]string), holepunchEndpoints: make(map[int]string),
holepunchStatus: make(map[int]bool), holepunchStatus: make(map[int]bool),
relayedPeers: make(map[int]bool), relayedPeers: make(map[int]bool),
relaySends: make(map[string]func()),
holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures
holepunchFailures: make(map[int]int), holepunchFailures: make(map[int]int),
// Rapid initial test settings: complete within ~1.5 seconds // Rapid initial test settings: complete within ~1.5 seconds
@@ -409,23 +396,20 @@ func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status Connectio
} }
} }
// sendRelay sends a relay message to the server with retry, keyed by chainId // sendRelay sends a relay message to the server
func (pm *PeerMonitor) sendRelay(siteID int) error { func (pm *PeerMonitor) sendRelay(siteID int) error {
if pm.wsClient == nil { if pm.wsClient == nil {
return fmt.Errorf("websocket client is nil") return fmt.Errorf("websocket client is nil")
} }
chainId := generateChainId() err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{
stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/relay", map[string]interface{}{
"siteId": siteID, "siteId": siteID,
"chainId": chainId, })
}, 2*time.Second, 10) if err != nil {
logger.Error("Failed to send registration message: %v", err)
pm.relaySendMu.Lock() return err
pm.relaySends[chainId] = stopFunc }
pm.relaySendMu.Unlock() logger.Info("Sent relay message")
logger.Info("Sent relay message for site %d (chain %s)", siteID, chainId)
return nil return nil
} }
@@ -435,52 +419,23 @@ func (pm *PeerMonitor) RequestRelay(siteID int) error {
return pm.sendRelay(siteID) return pm.sendRelay(siteID)
} }
// sendUnRelay sends an unrelay message to the server with retry, keyed by chainId // sendUnRelay sends an unrelay message to the server
func (pm *PeerMonitor) sendUnRelay(siteID int) error { func (pm *PeerMonitor) sendUnRelay(siteID int) error {
if pm.wsClient == nil { if pm.wsClient == nil {
return fmt.Errorf("websocket client is nil") return fmt.Errorf("websocket client is nil")
} }
chainId := generateChainId() err := pm.wsClient.SendMessage("olm/wg/unrelay", map[string]interface{}{
stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/unrelay", map[string]interface{}{
"siteId": siteID, "siteId": siteID,
"chainId": chainId, })
}, 2*time.Second, 10) if err != nil {
logger.Error("Failed to send registration message: %v", err)
pm.relaySendMu.Lock() return err
pm.relaySends[chainId] = stopFunc }
pm.relaySendMu.Unlock() logger.Info("Sent unrelay message")
logger.Info("Sent unrelay message for site %d (chain %s)", siteID, chainId)
return nil return nil
} }
// CancelRelaySend stops the interval sender for the given chainId, if one exists.
// If chainId is empty, all active relay senders are stopped.
func (pm *PeerMonitor) CancelRelaySend(chainId string) {
pm.relaySendMu.Lock()
defer pm.relaySendMu.Unlock()
if chainId == "" {
for id, stop := range pm.relaySends {
if stop != nil {
stop()
}
delete(pm.relaySends, id)
}
logger.Info("Cancelled all relay senders")
return
}
if stop, ok := pm.relaySends[chainId]; ok {
stop()
delete(pm.relaySends, chainId)
logger.Info("Cancelled relay sender for chain %s", chainId)
} else {
logger.Warn("CancelRelaySend: no active sender for chain %s", chainId)
}
}
// Stop stops monitoring all peers // Stop stops monitoring all peers
func (pm *PeerMonitor) Stop() { func (pm *PeerMonitor) Stop() {
// Stop holepunch monitor first (outside of mutex to avoid deadlock) // Stop holepunch monitor first (outside of mutex to avoid deadlock)
@@ -722,16 +677,6 @@ func (pm *PeerMonitor) Close() {
// Stop holepunch monitor first (outside of mutex to avoid deadlock) // Stop holepunch monitor first (outside of mutex to avoid deadlock)
pm.stopHolepunchMonitor() pm.stopHolepunchMonitor()
// Stop all pending relay senders
pm.relaySendMu.Lock()
for chainId, stop := range pm.relaySends {
if stop != nil {
stop()
}
delete(pm.relaySends, chainId)
}
pm.relaySendMu.Unlock()
pm.mutex.Lock() pm.mutex.Lock()
defer pm.mutex.Unlock() defer pm.mutex.Unlock()