Compare commits

..

9 Commits
dev ... jit

Author SHA1 Message Date
Owen
c2b5ef96a4 Jit of aliases working 2026-03-12 17:26:46 -07:00
Owen
e326da3d3e Merge branch 'dev' into jit 2026-03-12 16:53:16 -07:00
Owen
22cd02ae15 Alias jit handler 2026-03-11 15:56:51 -07:00
Owen
e2690bcc03 Store site id 2026-03-06 16:19:00 -08:00
Owen
f2d0e6a14c Merge branch 'dev' into jit 2026-03-06 16:08:24 -08:00
Owen
809dbe77de Make chainId in relay message bckwd compat 2026-03-06 15:27:03 -08:00
Owen
c67c2a60a1 Handle canceling sends for relay 2026-03-06 15:15:31 -08:00
Owen
051c0fdfd8 Working jit with chain ids 2026-03-04 17:51:48 -08:00
Owen
e7507e0837 Add api endpoints to jit 2026-03-04 17:01:17 -08:00
11 changed files with 476 additions and 109 deletions

View File

@@ -78,6 +78,13 @@ type MetadataChangeRequest struct {
Postures map[string]any `json:"postures"` Postures map[string]any `json:"postures"`
} }
// JITConnectionRequest defines the structure for a dynamic Just-In-Time connection request.
// Either SiteID or ResourceID must be provided (but not necessarily both).
type JITConnectionRequest struct {
Site string `json:"site,omitempty"`
Resource string `json:"resource,omitempty"`
}
// API represents the HTTP server and its state // API represents the HTTP server and its state
type API struct { type API struct {
addr string addr string
@@ -92,6 +99,7 @@ type API struct {
onExit func() error onExit func() error
onRebind func() error onRebind func() error
onPowerMode func(PowerModeRequest) error onPowerMode func(PowerModeRequest) error
onJITConnect func(JITConnectionRequest) error
statusMu sync.RWMutex statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus peerStatuses map[int]*PeerStatus
@@ -143,6 +151,7 @@ func (s *API) SetHandlers(
onExit func() error, onExit func() error,
onRebind func() error, onRebind func() error,
onPowerMode func(PowerModeRequest) error, onPowerMode func(PowerModeRequest) error,
onJITConnect func(JITConnectionRequest) error,
) { ) {
s.onConnect = onConnect s.onConnect = onConnect
s.onSwitchOrg = onSwitchOrg s.onSwitchOrg = onSwitchOrg
@@ -151,6 +160,7 @@ func (s *API) SetHandlers(
s.onExit = onExit s.onExit = onExit
s.onRebind = onRebind s.onRebind = onRebind
s.onPowerMode = onPowerMode s.onPowerMode = onPowerMode
s.onJITConnect = onJITConnect
} }
// Start starts the HTTP server // Start starts the HTTP server
@@ -169,6 +179,7 @@ func (s *API) Start() error {
mux.HandleFunc("/health", s.handleHealth) mux.HandleFunc("/health", s.handleHealth)
mux.HandleFunc("/rebind", s.handleRebind) mux.HandleFunc("/rebind", s.handleRebind)
mux.HandleFunc("/power-mode", s.handlePowerMode) mux.HandleFunc("/power-mode", s.handlePowerMode)
mux.HandleFunc("/jit-connect", s.handleJITConnect)
s.server = &http.Server{ s.server = &http.Server{
Handler: mux, Handler: mux,
@@ -633,6 +644,54 @@ func (s *API) handleRebind(w http.ResponseWriter, r *http.Request) {
}) })
} }
// handleJITConnect handles the /jit-connect endpoint.
// It initiates a dynamic Just-In-Time connection to a site identified by either
// a site or a resource. Exactly one of the two must be provided.
func (s *API) handleJITConnect(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req JITConnectionRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
return
}
// Validate that exactly one of site or resource is provided
if req.Site == "" && req.Resource == "" {
http.Error(w, "Missing required field: either site or resource must be provided", http.StatusBadRequest)
return
}
if req.Site != "" && req.Resource != "" {
http.Error(w, "Ambiguous request: provide either site or resource, not both", http.StatusBadRequest)
return
}
if req.Site != "" {
logger.Info("Received JIT connection request via API: site=%s", req.Site)
} else {
logger.Info("Received JIT connection request via API: resource=%s", req.Resource)
}
if s.onJITConnect != nil {
if err := s.onJITConnect(req); err != nil {
http.Error(w, fmt.Sprintf("JIT connection failed: %v", err), http.StatusInternalServerError)
return
}
} else {
http.Error(w, "JIT connect handler not configured", http.StatusNotImplemented)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusAccepted)
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "JIT connection request accepted",
})
}
// handlePowerMode handles the /power-mode endpoint // handlePowerMode handles the /power-mode endpoint
// This allows changing the power mode between "normal" and "low" // This allows changing the power mode between "normal" and "low"
func (s *API) handlePowerMode(w http.ResponseWriter, r *http.Request) { func (s *API) handlePowerMode(w http.ResponseWriter, r *http.Request) {

View File

@@ -45,6 +45,11 @@ type DNSProxy struct {
tunnelActivePorts map[uint16]bool tunnelActivePorts map[uint16]bool
tunnelPortsLock sync.Mutex tunnelPortsLock sync.Mutex
// jitHandler is called when a local record is resolved for a site that may not be
// connected yet, giving the caller a chance to initiate a JIT connection.
// It is invoked asynchronously so it never blocks DNS resolution.
jitHandler func(siteId int)
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
wg sync.WaitGroup wg sync.WaitGroup
@@ -384,6 +389,16 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie
response = p.checkLocalRecords(msg, question) response = p.checkLocalRecords(msg, question)
} }
// If a local A/AAAA record was found, notify the JIT handler so that the owning
// site can be connected on-demand if it is not yet active.
if response != nil && p.jitHandler != nil &&
(question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA) {
if siteId, ok := p.recordStore.GetSiteIdForDomain(question.Name); ok && siteId != 0 {
handler := p.jitHandler
go handler(siteId)
}
}
// If no local records, forward to upstream // If no local records, forward to upstream
if response == nil { if response == nil {
logger.Debug("No local record for %s, forwarding upstream to %v", question.Name, p.upstreamDNS) logger.Debug("No local record for %s, forwarding upstream to %v", question.Name, p.upstreamDNS)
@@ -718,11 +733,20 @@ func (p *DNSProxy) runPacketSender() {
} }
} }
// SetJITHandler registers a callback that is invoked whenever a local DNS record is
// resolved for an A or AAAA query. The siteId identifies which site owns the record.
// The handler is called in its own goroutine so it must be safe to call concurrently.
// Pass nil to disable JIT notifications.
func (p *DNSProxy) SetJITHandler(handler func(siteId int)) {
p.jitHandler = handler
}
// AddDNSRecord adds a DNS record to the local store // AddDNSRecord adds a DNS record to the local store
// domain should be a domain name (e.g., "example.com" or "example.com.") // domain should be a domain name (e.g., "example.com" or "example.com.")
// ip should be a valid IPv4 or IPv6 address // ip should be a valid IPv4 or IPv6 address
func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP) error { func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP, siteId int) error {
return p.recordStore.AddRecord(domain, ip) logger.Debug("Adding dns record for domain %s with IP %s (siteId=%d)", domain, ip.String(), siteId)
return p.recordStore.AddRecord(domain, ip, siteId)
} }
// RemoveDNSRecord removes a DNS record from the local store // RemoveDNSRecord removes a DNS record from the local store

View File

@@ -14,7 +14,7 @@ func TestCheckLocalRecordsNODATAForAAAA(t *testing.T) {
// Add an A record for a domain (no AAAA record) // Add an A record for a domain (no AAAA record)
ip := net.ParseIP("10.0.0.1") ip := net.ParseIP("10.0.0.1")
err := proxy.recordStore.AddRecord("myservice.internal", ip) err := proxy.recordStore.AddRecord("myservice.internal", ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add A record: %v", err) t.Fatalf("Failed to add A record: %v", err)
} }
@@ -64,7 +64,7 @@ func TestCheckLocalRecordsNODATAForA(t *testing.T) {
// Add an AAAA record for a domain (no A record) // Add an AAAA record for a domain (no A record)
ip := net.ParseIP("2001:db8::1") ip := net.ParseIP("2001:db8::1")
err := proxy.recordStore.AddRecord("ipv6only.internal", ip) err := proxy.recordStore.AddRecord("ipv6only.internal", ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add AAAA record: %v", err) t.Fatalf("Failed to add AAAA record: %v", err)
} }
@@ -113,7 +113,7 @@ func TestCheckLocalRecordsNonExistentDomain(t *testing.T) {
} }
// Add a record so the store isn't empty // Add a record so the store isn't empty
err := proxy.recordStore.AddRecord("exists.internal", net.ParseIP("10.0.0.1")) err := proxy.recordStore.AddRecord("exists.internal", net.ParseIP("10.0.0.1"), 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add record: %v", err) t.Fatalf("Failed to add record: %v", err)
} }
@@ -144,7 +144,7 @@ func TestCheckLocalRecordsNODATAWildcard(t *testing.T) {
// Add a wildcard A record (no AAAA) // Add a wildcard A record (no AAAA)
ip := net.ParseIP("10.0.0.1") ip := net.ParseIP("10.0.0.1")
err := proxy.recordStore.AddRecord("*.wildcard.internal", ip) err := proxy.recordStore.AddRecord("*.wildcard.internal", ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard A record: %v", err) t.Fatalf("Failed to add wildcard A record: %v", err)
} }

View File

@@ -20,8 +20,9 @@ const (
// recordSet holds A and AAAA records for a single domain or wildcard pattern // recordSet holds A and AAAA records for a single domain or wildcard pattern
type recordSet struct { type recordSet struct {
A []net.IP A []net.IP
AAAA []net.IP AAAA []net.IP
SiteId int
} }
// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries. // DNSRecordStore manages local DNS records for A, AAAA, and PTR queries.
@@ -46,8 +47,9 @@ func NewDNSRecordStore() *DNSRecordStore {
// domain should be in FQDN format (e.g., "example.com.") // domain should be in FQDN format (e.g., "example.com.")
// domain can contain wildcards: * (0+ chars) and ? (exactly 1 char) // domain can contain wildcards: * (0+ chars) and ? (exactly 1 char)
// ip should be a valid IPv4 or IPv6 address // ip should be a valid IPv4 or IPv6 address
// siteId is the site that owns this alias/domain
// Automatically adds a corresponding PTR record for non-wildcard domains // Automatically adds a corresponding PTR record for non-wildcard domains
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error { func (s *DNSRecordStore) AddRecord(domain string, ip net.IP, siteId int) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@@ -69,12 +71,22 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
} }
if m[domain] == nil { if m[domain] == nil {
m[domain] = &recordSet{} m[domain] = &recordSet{SiteId: siteId}
} }
rs := m[domain] rs := m[domain]
if isV4 { if isV4 {
for _, existing := range rs.A {
if existing.Equal(ip) {
return nil
}
}
rs.A = append(rs.A, ip) rs.A = append(rs.A, ip)
} else { } else {
for _, existing := range rs.AAAA {
if existing.Equal(ip) {
return nil
}
}
rs.AAAA = append(rs.AAAA, ip) rs.AAAA = append(rs.AAAA, ip)
} }
@@ -85,6 +97,7 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
return nil return nil
} }
// AddPTRRecord adds a PTR record mapping an IP address to a domain name // AddPTRRecord adds a PTR record mapping an IP address to a domain name
// ip should be a valid IPv4 or IPv6 address // ip should be a valid IPv4 or IPv6 address
// domain should be in FQDN format (e.g., "example.com.") // domain should be in FQDN format (e.g., "example.com.")
@@ -179,6 +192,30 @@ func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) {
delete(s.ptrRecords, ip.String()) delete(s.ptrRecords, ip.String())
} }
// GetSiteIdForDomain returns the siteId associated with the given domain.
// It checks exact matches first, then wildcard patterns.
// The second return value is false if the domain is not found in local records.
func (s *DNSRecordStore) GetSiteIdForDomain(domain string) (int, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
domain = strings.ToLower(dns.Fqdn(domain))
// Check exact match first
if rs, exists := s.exact[domain]; exists {
return rs.SiteId, true
}
// Check wildcard matches
for pattern, rs := range s.wildcards {
if matchWildcard(pattern, domain) {
return rs.SiteId, true
}
}
return 0, false
}
// GetRecords returns all IP addresses for a domain and record type. // GetRecords returns all IP addresses for a domain and record type.
// The second return value indicates whether the domain exists at all // The second return value indicates whether the domain exists at all
// (true = domain exists, use NODATA if no records; false = NXDOMAIN). // (true = domain exists, use NODATA if no records; false = NXDOMAIN).

View File

@@ -170,14 +170,14 @@ func TestDNSRecordStoreWildcard(t *testing.T) {
// Add wildcard records // Add wildcard records
wildcardIP := net.ParseIP("10.0.0.1") wildcardIP := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", wildcardIP) err := store.AddRecord("*.autoco.internal", wildcardIP, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err) t.Fatalf("Failed to add wildcard record: %v", err)
} }
// Add exact record // Add exact record
exactIP := net.ParseIP("10.0.0.2") exactIP := net.ParseIP("10.0.0.2")
err = store.AddRecord("exact.autoco.internal", exactIP) err = store.AddRecord("exact.autoco.internal", exactIP, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add exact record: %v", err) t.Fatalf("Failed to add exact record: %v", err)
} }
@@ -221,7 +221,7 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) {
// Add complex wildcard pattern // Add complex wildcard pattern
ip1 := net.ParseIP("10.0.0.1") ip1 := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.host-0?.autoco.internal", ip1) err := store.AddRecord("*.host-0?.autoco.internal", ip1, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err) t.Fatalf("Failed to add wildcard record: %v", err)
} }
@@ -262,7 +262,7 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
// Add wildcard record // Add wildcard record
ip := net.ParseIP("10.0.0.1") ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", ip) err := store.AddRecord("*.autoco.internal", ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err) t.Fatalf("Failed to add wildcard record: %v", err)
} }
@@ -297,18 +297,18 @@ func TestDNSRecordStoreMultipleWildcards(t *testing.T) {
ip2 := net.ParseIP("10.0.0.2") ip2 := net.ParseIP("10.0.0.2")
ip3 := net.ParseIP("10.0.0.3") ip3 := net.ParseIP("10.0.0.3")
err := store.AddRecord("*.prod.autoco.internal", ip1) err := store.AddRecord("*.prod.autoco.internal", ip1, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add first wildcard: %v", err) t.Fatalf("Failed to add first wildcard: %v", err)
} }
err = store.AddRecord("*.dev.autoco.internal", ip2) err = store.AddRecord("*.dev.autoco.internal", ip2, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add second wildcard: %v", err) t.Fatalf("Failed to add second wildcard: %v", err)
} }
// Add a broader wildcard that matches both // Add a broader wildcard that matches both
err = store.AddRecord("*.autoco.internal", ip3) err = store.AddRecord("*.autoco.internal", ip3, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add third wildcard: %v", err) t.Fatalf("Failed to add third wildcard: %v", err)
} }
@@ -337,7 +337,7 @@ func TestDNSRecordStoreIPv6Wildcard(t *testing.T) {
// Add IPv6 wildcard record // Add IPv6 wildcard record
ip := net.ParseIP("2001:db8::1") ip := net.ParseIP("2001:db8::1")
err := store.AddRecord("*.autoco.internal", ip) err := store.AddRecord("*.autoco.internal", ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add IPv6 wildcard record: %v", err) t.Fatalf("Failed to add IPv6 wildcard record: %v", err)
} }
@@ -357,7 +357,7 @@ func TestHasRecordWildcard(t *testing.T) {
// Add wildcard record // Add wildcard record
ip := net.ParseIP("10.0.0.1") ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", ip) err := store.AddRecord("*.autoco.internal", ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err) t.Fatalf("Failed to add wildcard record: %v", err)
} }
@@ -378,7 +378,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
// Add record with mixed case // Add record with mixed case
ip := net.ParseIP("10.0.0.1") ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("MyHost.AutoCo.Internal", ip) err := store.AddRecord("MyHost.AutoCo.Internal", ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add mixed case record: %v", err) t.Fatalf("Failed to add mixed case record: %v", err)
} }
@@ -403,7 +403,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
// Test wildcard with mixed case // Test wildcard with mixed case
wildcardIP := net.ParseIP("10.0.0.2") wildcardIP := net.ParseIP("10.0.0.2")
err = store.AddRecord("*.Example.Com", wildcardIP) err = store.AddRecord("*.Example.Com", wildcardIP, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add mixed case wildcard: %v", err) t.Fatalf("Failed to add mixed case wildcard: %v", err)
} }
@@ -689,7 +689,7 @@ func TestClearPTRRecords(t *testing.T) {
store.AddPTRRecord(ip2, "host2.example.com.") store.AddPTRRecord(ip2, "host2.example.com.")
// Add some A records too // Add some A records too
store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1")) store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1"), 0)
// Verify PTR records exist // Verify PTR records exist
if !store.HasPTRRecord("1.1.168.192.in-addr.arpa.") { if !store.HasPTRRecord("1.1.168.192.in-addr.arpa.") {
@@ -719,7 +719,7 @@ func TestAutomaticPTRRecordOnAdd(t *testing.T) {
// Add an A record - should automatically add PTR record // Add an A record - should automatically add PTR record
domain := "host.example.com." domain := "host.example.com."
ip := net.ParseIP("192.168.1.100") ip := net.ParseIP("192.168.1.100")
err := store.AddRecord(domain, ip) err := store.AddRecord(domain, ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add A record: %v", err) t.Fatalf("Failed to add A record: %v", err)
} }
@@ -737,7 +737,7 @@ func TestAutomaticPTRRecordOnAdd(t *testing.T) {
// Add AAAA record - should also automatically add PTR record // Add AAAA record - should also automatically add PTR record
domain6 := "ipv6host.example.com." domain6 := "ipv6host.example.com."
ip6 := net.ParseIP("2001:db8::1") ip6 := net.ParseIP("2001:db8::1")
err = store.AddRecord(domain6, ip6) err = store.AddRecord(domain6, ip6, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add AAAA record: %v", err) t.Fatalf("Failed to add AAAA record: %v", err)
} }
@@ -759,7 +759,7 @@ func TestAutomaticPTRRecordOnRemove(t *testing.T) {
// Add an A record (with automatic PTR) // Add an A record (with automatic PTR)
domain := "host.example.com." domain := "host.example.com."
ip := net.ParseIP("192.168.1.100") ip := net.ParseIP("192.168.1.100")
store.AddRecord(domain, ip) store.AddRecord(domain, ip, 0)
// Verify PTR exists // Verify PTR exists
reverseDomain := "100.1.168.192.in-addr.arpa." reverseDomain := "100.1.168.192.in-addr.arpa."
@@ -789,8 +789,8 @@ func TestAutomaticPTRRecordOnRemoveAll(t *testing.T) {
domain := "host.example.com." domain := "host.example.com."
ip1 := net.ParseIP("192.168.1.100") ip1 := net.ParseIP("192.168.1.100")
ip2 := net.ParseIP("192.168.1.101") ip2 := net.ParseIP("192.168.1.101")
store.AddRecord(domain, ip1) store.AddRecord(domain, ip1, 0)
store.AddRecord(domain, ip2) store.AddRecord(domain, ip2, 0)
// Verify both PTR records exist // Verify both PTR records exist
reverseDomain1 := "100.1.168.192.in-addr.arpa." reverseDomain1 := "100.1.168.192.in-addr.arpa."
@@ -820,7 +820,7 @@ func TestNoPTRForWildcardRecords(t *testing.T) {
// Add wildcard record - should NOT create PTR record // Add wildcard record - should NOT create PTR record
domain := "*.example.com." domain := "*.example.com."
ip := net.ParseIP("192.168.1.100") ip := net.ParseIP("192.168.1.100")
err := store.AddRecord(domain, ip) err := store.AddRecord(domain, ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err) t.Fatalf("Failed to add wildcard record: %v", err)
} }
@@ -844,7 +844,7 @@ func TestPTRRecordOverwrite(t *testing.T) {
// Add first domain with IP // Add first domain with IP
domain1 := "host1.example.com." domain1 := "host1.example.com."
ip := net.ParseIP("192.168.1.100") ip := net.ParseIP("192.168.1.100")
store.AddRecord(domain1, ip) store.AddRecord(domain1, ip, 0)
// Verify PTR points to first domain // Verify PTR points to first domain
reverseDomain := "100.1.168.192.in-addr.arpa." reverseDomain := "100.1.168.192.in-addr.arpa."
@@ -858,7 +858,7 @@ func TestPTRRecordOverwrite(t *testing.T) {
// Add second domain with same IP - should overwrite PTR // Add second domain with same IP - should overwrite PTR
domain2 := "host2.example.com." domain2 := "host2.example.com."
store.AddRecord(domain2, ip) store.AddRecord(domain2, ip, 0)
// Verify PTR now points to second domain (last one added) // Verify PTR now points to second domain (last one added)
result, ok = store.GetPTRRecord(reverseDomain) result, ok = store.GetPTRRecord(reverseDomain)

View File

@@ -7,6 +7,7 @@ import (
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"github.com/fosrl/newt/network" "github.com/fosrl/newt/network"
@@ -174,21 +175,19 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
for i := range wgData.Sites { for i := range wgData.Sites {
site := wgData.Sites[i] site := wgData.Sites[i]
if site.PublicKey == "" { if site.PublicKey != "" {
logger.Warn("Skipping site %d (%s): no public key available (site may not be connected)", site.SiteId, site.Name) var siteEndpoint string
continue // here we are going to take the relay endpoint if it exists which means we requested a relay for this peer
if site.RelayEndpoint != "" {
siteEndpoint = site.RelayEndpoint
} else {
siteEndpoint = site.Endpoint
}
o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false)
} }
var siteEndpoint string // we still call this to add the aliases for jit lookup but we just do that then pass inside. need to skip the above so we dont add to the api
// here we are going to take the relay endpoint if it exists which means we requested a relay for this peer
if site.RelayEndpoint != "" {
siteEndpoint = site.RelayEndpoint
} else {
siteEndpoint = site.Endpoint
}
o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false)
if err := o.peerManager.AddPeer(site); err != nil { if err := o.peerManager.AddPeer(site); err != nil {
logger.Error("Failed to add peer: %v", err) logger.Error("Failed to add peer: %v", err)
return return
@@ -203,6 +202,36 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
logger.Error("Failed to start DNS proxy: %v", err) logger.Error("Failed to start DNS proxy: %v", err)
} }
// Register JIT handler: when the DNS proxy resolves a local record, check whether
// the owning site is already connected and, if not, initiate a JIT connection.
o.dnsProxy.SetJITHandler(func(siteId int) {
if o.peerManager == nil || o.websocket == nil {
return
}
// Site already has an active peer connection - nothing to do.
if _, exists := o.peerManager.GetPeer(siteId); exists {
return
}
o.peerSendMu.Lock()
defer o.peerSendMu.Unlock()
// A JIT request for this site is already in-flight - avoid duplicate sends.
if _, pending := o.jitPendingSites[siteId]; pending {
return
}
chainId := generateChainId()
logger.Info("DNS-triggered JIT connect for site %d (chainId=%s)", siteId, chainId)
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{
"siteId": siteId,
"chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerInits[chainId] = stopFunc
o.jitPendingSites[siteId] = chainId
})
if o.tunnelConfig.OverrideDNS { if o.tunnelConfig.OverrideDNS {
// Set up DNS override to use our DNS proxy // Set up DNS override to use our DNS proxy
if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil { if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil {
@@ -280,12 +309,12 @@ func (o *Olm) handleTerminate(msg websocket.WSMessage) {
logger.Error("Error unmarshaling terminate error data: %v", err) logger.Error("Error unmarshaling terminate error data: %v", err)
} else { } else {
logger.Info("Terminate reason (code: %s): %s", errorData.Code, errorData.Message) logger.Info("Terminate reason (code: %s): %s", errorData.Code, errorData.Message)
if errorData.Code == "TERMINATED_INACTIVITY" { if errorData.Code == "TERMINATED_INACTIVITY" {
logger.Info("Ignoring...") logger.Info("Ignoring...")
return return
} }
// Set the olm error in the API server so it can be exposed via status // Set the olm error in the API server so it can be exposed via status
o.apiServer.SetOlmError(errorData.Code, errorData.Message) o.apiServer.SetOlmError(errorData.Code, errorData.Message)
} }

View File

@@ -2,6 +2,7 @@ package olm
import ( import (
"encoding/json" "encoding/json"
"fmt"
"time" "time"
"github.com/fosrl/newt/holepunch" "github.com/fosrl/newt/holepunch"
@@ -220,6 +221,7 @@ func (o *Olm) handleSync(msg websocket.WSMessage) {
logger.Info("Sync: Adding new peer for site %d", siteId) logger.Info("Sync: Adding new peer for site %d", siteId)
o.holePunchManager.TriggerHolePunch() o.holePunchManager.TriggerHolePunch()
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
// // TODO: do we need to send the message to the cloud to add the peer that way? // // TODO: do we need to send the message to the cloud to add the peer that way?
// if err := o.peerManager.AddPeer(expectedSite); err != nil { // if err := o.peerManager.AddPeer(expectedSite); err != nil {
@@ -230,9 +232,17 @@ func (o *Olm) handleSync(msg websocket.WSMessage) {
// add the peer via the server // add the peer via the server
// this is important because newt needs to get triggered as well to add the peer once the hp is complete // this is important because newt needs to get triggered as well to add the peer once the hp is complete
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ chainId := fmt.Sprintf("sync-%d", expectedSite.SiteId)
"siteId": expectedSite.SiteId, o.peerSendMu.Lock()
}, 1*time.Second, 10) if stop, ok := o.stopPeerSends[chainId]; ok {
stop()
}
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": expectedSite.SiteId,
"chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerSends[chainId] = stopFunc
o.peerSendMu.Unlock()
} else { } else {
// Existing peer - check if update is needed // Existing peer - check if update is needed

View File

@@ -2,6 +2,8 @@ package olm
import ( import (
"context" "context"
"crypto/rand"
"encoding/hex"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@@ -65,7 +67,10 @@ type Olm struct {
stopRegister func() stopRegister func()
updateRegister func(newData any) updateRegister func(newData any)
stopPeerSend func() stopPeerSends map[string]func()
stopPeerInits map[string]func()
jitPendingSites map[int]string // siteId -> chainId for in-flight JIT requests
peerSendMu sync.Mutex
// WaitGroup to track tunnel lifecycle // WaitGroup to track tunnel lifecycle
tunnelWg sync.WaitGroup tunnelWg sync.WaitGroup
@@ -116,6 +121,13 @@ func (o *Olm) initTunnelInfo(clientID string) error {
return nil return nil
} }
// generateChainId generates a random chain ID for tracking peer sender lifecycles.
func generateChainId() string {
b := make([]byte, 8)
_, _ = rand.Read(b)
return hex.EncodeToString(b)
}
func Init(ctx context.Context, config OlmConfig) (*Olm, error) { func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel)) logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
@@ -166,10 +178,13 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
apiServer.SetAgent(config.Agent) apiServer.SetAgent(config.Agent)
newOlm := &Olm{ newOlm := &Olm{
logFile: logFile, logFile: logFile,
olmCtx: ctx, olmCtx: ctx,
apiServer: apiServer, apiServer: apiServer,
olmConfig: config, olmConfig: config,
stopPeerSends: make(map[string]func()),
stopPeerInits: make(map[string]func()),
jitPendingSites: make(map[int]string),
} }
newOlm.registerAPICallbacks() newOlm.registerAPICallbacks()
@@ -284,6 +299,21 @@ func (o *Olm) registerAPICallbacks() {
logger.Info("Processing power mode change request via API: mode=%s", req.Mode) logger.Info("Processing power mode change request via API: mode=%s", req.Mode)
return o.SetPowerMode(req.Mode) return o.SetPowerMode(req.Mode)
}, },
func(req api.JITConnectionRequest) error {
logger.Info("Processing JIT connect request via API: site=%s resource=%s", req.Site, req.Resource)
chainId := generateChainId()
o.peerSendMu.Lock()
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{
"siteId": req.Site,
"resourceId": req.Resource,
"chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerInits[chainId] = stopFunc
o.peerSendMu.Unlock()
return nil
},
) )
} }
@@ -385,6 +415,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
// Handler for peer handshake - adds exit node to holepunch rotation and notifies server // Handler for peer handshake - adds exit node to holepunch rotation and notifies server
o.websocket.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite) o.websocket.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite)
o.websocket.RegisterHandler("olm/wg/peer/chain/cancel", o.handleCancelChain)
o.websocket.RegisterHandler("olm/sync", o.handleSync) o.websocket.RegisterHandler("olm/sync", o.handleSync)
o.websocket.OnConnect(func() error { o.websocket.OnConnect(func() error {
@@ -427,7 +458,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
"userToken": userToken, "userToken": userToken,
"fingerprint": o.fingerprint, "fingerprint": o.fingerprint,
"postures": o.postures, "postures": o.postures,
}, 1*time.Second, 10) }, 2*time.Second, 10)
// Invoke onRegistered callback if configured // Invoke onRegistered callback if configured
if o.olmConfig.OnRegistered != nil { if o.olmConfig.OnRegistered != nil {
@@ -524,6 +555,23 @@ func (o *Olm) Close() {
o.stopRegister = nil o.stopRegister = nil
} }
// Stop all pending peer init and send senders before closing websocket
o.peerSendMu.Lock()
for _, stop := range o.stopPeerInits {
if stop != nil {
stop()
}
}
o.stopPeerInits = make(map[string]func())
for _, stop := range o.stopPeerSends {
if stop != nil {
stop()
}
}
o.stopPeerSends = make(map[string]func())
o.jitPendingSites = make(map[int]string)
o.peerSendMu.Unlock()
// send a disconnect message to the cloud to show disconnected // send a disconnect message to the cloud to show disconnected
if o.websocket != nil { if o.websocket != nil {
o.websocket.SendMessage("olm/disconnecting", map[string]any{}) o.websocket.SendMessage("olm/disconnecting", map[string]any{})

View File

@@ -20,36 +20,43 @@ func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
return return
} }
if o.stopPeerSend != nil {
o.stopPeerSend()
o.stopPeerSend = nil
}
jsonData, err := json.Marshal(msg.Data) jsonData, err := json.Marshal(msg.Data)
if err != nil { if err != nil {
logger.Error("Error marshaling data: %v", err) logger.Error("Error marshaling data: %v", err)
return return
} }
var siteConfig peers.SiteConfig var siteConfigMsg struct {
if err := json.Unmarshal(jsonData, &siteConfig); err != nil { peers.SiteConfig
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &siteConfigMsg); err != nil {
logger.Error("Error unmarshaling add data: %v", err) logger.Error("Error unmarshaling add data: %v", err)
return return
} }
if siteConfig.PublicKey == "" { if siteConfigMsg.ChainId != "" {
logger.Warn("Skipping add-peer for site %d (%s): no public key available (site may not be connected)", siteConfig.SiteId, siteConfig.Name) o.peerSendMu.Lock()
if stop, ok := o.stopPeerSends[siteConfigMsg.ChainId]; ok {
stop()
delete(o.stopPeerSends, siteConfigMsg.ChainId)
}
o.peerSendMu.Unlock()
}
if siteConfigMsg.PublicKey == "" {
logger.Warn("Skipping add-peer for site %d (%s): no public key available (site may not be connected)", siteConfigMsg.SiteId, siteConfigMsg.Name)
return return
} }
_ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it _ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
if err := o.peerManager.AddPeer(siteConfig); err != nil { if err := o.peerManager.AddPeer(siteConfigMsg.SiteConfig); err != nil {
logger.Error("Failed to add peer: %v", err) logger.Error("Failed to add peer: %v", err)
return return
} }
logger.Info("Successfully added peer for site %d", siteConfig.SiteId) logger.Info("Successfully added peer for site %d", siteConfigMsg.SiteId)
} }
func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) { func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
@@ -169,13 +176,21 @@ func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) {
return return
} }
var relayData peers.RelayPeerData var relayData struct {
peers.RelayPeerData
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &relayData); err != nil { if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err) logger.Error("Error unmarshaling relay data: %v", err)
return return
} }
if monitor := o.peerManager.GetPeerMonitor(); monitor != nil {
monitor.CancelRelaySend(relayData.ChainId)
}
primaryRelay, err := util.ResolveDomainUpstream(relayData.RelayEndpoint, o.tunnelConfig.PublicDNS) primaryRelay, err := util.ResolveDomainUpstream(relayData.RelayEndpoint, o.tunnelConfig.PublicDNS)
if err != nil { if err != nil {
logger.Error("Failed to resolve primary relay endpoint: %v", err) logger.Error("Failed to resolve primary relay endpoint: %v", err)
return return
@@ -202,13 +217,21 @@ func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
return return
} }
var relayData peers.UnRelayPeerData var relayData struct {
peers.UnRelayPeerData
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &relayData); err != nil { if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err) logger.Error("Error unmarshaling relay data: %v", err)
return return
} }
if monitor := o.peerManager.GetPeerMonitor(); monitor != nil {
monitor.CancelRelaySend(relayData.ChainId)
}
primaryRelay, err := util.ResolveDomainUpstream(relayData.Endpoint, o.tunnelConfig.PublicDNS) primaryRelay, err := util.ResolveDomainUpstream(relayData.Endpoint, o.tunnelConfig.PublicDNS)
if err != nil { if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err) logger.Warn("Failed to resolve primary relay endpoint: %v", err)
} }
@@ -235,7 +258,8 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
} }
var handshakeData struct { var handshakeData struct {
SiteId int `json:"siteId"` SiteId int `json:"siteId"`
ChainId string `json:"chainId"`
ExitNode struct { ExitNode struct {
PublicKey string `json:"publicKey"` PublicKey string `json:"publicKey"`
Endpoint string `json:"endpoint"` Endpoint string `json:"endpoint"`
@@ -248,6 +272,19 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
return return
} }
// Stop the peer init sender for this chain, if any
if handshakeData.ChainId != "" {
o.peerSendMu.Lock()
if stop, ok := o.stopPeerInits[handshakeData.ChainId]; ok {
stop()
delete(o.stopPeerInits, handshakeData.ChainId)
}
// If this chain was initiated by a DNS-triggered JIT request, clear the
// pending entry so the site can be re-triggered if needed in the future.
delete(o.jitPendingSites, handshakeData.SiteId)
o.peerSendMu.Unlock()
}
// Get existing peer from PeerManager // Get existing peer from PeerManager
_, exists := o.peerManager.GetPeer(handshakeData.SiteId) _, exists := o.peerManager.GetPeer(handshakeData.SiteId)
if exists { if exists {
@@ -278,10 +315,72 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
// Send handshake acknowledgment back to server with retry // Send handshake acknowledgment back to server with retry, keyed by chainId
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ chainId := handshakeData.ChainId
"siteId": handshakeData.SiteId, if chainId == "" {
}, 1*time.Second, 10) chainId = generateChainId()
}
o.peerSendMu.Lock()
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": handshakeData.SiteId,
"chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerSends[chainId] = stopFunc
o.peerSendMu.Unlock()
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
} }
func (o *Olm) handleCancelChain(msg websocket.WSMessage) {
logger.Debug("Received cancel-chain message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling cancel-chain data: %v", err)
return
}
var cancelData struct {
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &cancelData); err != nil {
logger.Error("Error unmarshaling cancel-chain data: %v", err)
return
}
if cancelData.ChainId == "" {
logger.Warn("Received cancel-chain message with no chainId")
return
}
o.peerSendMu.Lock()
defer o.peerSendMu.Unlock()
found := false
if stop, ok := o.stopPeerInits[cancelData.ChainId]; ok {
stop()
delete(o.stopPeerInits, cancelData.ChainId)
found = true
}
// If this chain was a DNS-triggered JIT request, clear the pending entry so
// the site can be re-triggered on the next DNS lookup.
for siteId, chainId := range o.jitPendingSites {
if chainId == cancelData.ChainId {
delete(o.jitPendingSites, siteId)
break
}
}
if stop, ok := o.stopPeerSends[cancelData.ChainId]; ok {
stop()
delete(o.stopPeerSends, cancelData.ChainId)
found = true
}
if found {
logger.Info("Cancelled chain %s", cancelData.ChainId)
} else {
logger.Warn("Cancel-chain: no active sender found for chain %s", cancelData.ChainId)
}
}

View File

@@ -110,6 +110,19 @@ func (pm *PeerManager) GetAllPeers() []SiteConfig {
func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error { func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
for _, alias := range siteConfig.Aliases {
address := net.ParseIP(alias.AliasAddress)
if address == nil {
continue
}
pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteConfig.SiteId)
}
if siteConfig.PublicKey == "" {
logger.Debug("Skip adding site %d because no pub key", siteConfig.SiteId)
return nil
}
// build the allowed IPs list from the remote subnets and aliases and add them to the peer // build the allowed IPs list from the remote subnets and aliases and add them to the peer
allowedIPs := make([]string, 0, len(siteConfig.RemoteSubnets)+len(siteConfig.Aliases)) allowedIPs := make([]string, 0, len(siteConfig.RemoteSubnets)+len(siteConfig.Aliases))
@@ -143,14 +156,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
if err := network.AddRoutes(siteConfig.RemoteSubnets, pm.interfaceName); err != nil { if err := network.AddRoutes(siteConfig.RemoteSubnets, pm.interfaceName); err != nil {
logger.Error("Failed to add routes for remote subnets: %v", err) logger.Error("Failed to add routes for remote subnets: %v", err)
} }
for _, alias := range siteConfig.Aliases {
address := net.ParseIP(alias.AliasAddress)
if address == nil {
continue
}
pm.dnsProxy.AddDNSRecord(alias.Alias, address)
}
monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0] monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
@@ -437,7 +443,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
if address == nil { if address == nil {
continue continue
} }
pm.dnsProxy.AddDNSRecord(alias.Alias, address) pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteConfig.SiteId)
} }
pm.peerMonitor.UpdateHolepunchEndpoint(siteConfig.SiteId, siteConfig.Endpoint) pm.peerMonitor.UpdateHolepunchEndpoint(siteConfig.SiteId, siteConfig.Endpoint)
@@ -717,7 +723,7 @@ func (pm *PeerManager) AddAlias(siteId int, alias Alias) error {
address := net.ParseIP(alias.AliasAddress) address := net.ParseIP(alias.AliasAddress)
if address != nil { if address != nil {
pm.dnsProxy.AddDNSRecord(alias.Alias, address) pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteId)
} }
// Add an allowed IP for the alias // Add an allowed IP for the alias

View File

@@ -2,6 +2,8 @@ package monitor
import ( import (
"context" "context"
"crypto/rand"
"encoding/hex"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@@ -31,11 +33,15 @@ type PeerMonitor struct {
monitors map[int]*Client monitors map[int]*Client
mutex sync.Mutex mutex sync.Mutex
running bool running bool
timeout time.Duration timeout time.Duration
maxAttempts int maxAttempts int
wsClient *websocket.Client wsClient *websocket.Client
publicDNS []string publicDNS []string
// Relay sender tracking
relaySends map[string]func()
relaySendMu sync.Mutex
// Netstack fields // Netstack fields
middleDev *middleDevice.MiddleDevice middleDev *middleDevice.MiddleDevice
localIP string localIP string
@@ -48,13 +54,13 @@ type PeerMonitor struct {
nsWg sync.WaitGroup nsWg sync.WaitGroup
// Holepunch testing fields // Holepunch testing fields
sharedBind *bind.SharedBind sharedBind *bind.SharedBind
holepunchTester *holepunch.HolepunchTester holepunchTester *holepunch.HolepunchTester
holepunchTimeout time.Duration holepunchTimeout time.Duration
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
holepunchStatus map[int]bool // siteID -> connected status holepunchStatus map[int]bool // siteID -> connected status
holepunchStopChan chan struct{} holepunchStopChan chan struct{}
holepunchUpdateChan chan struct{} holepunchUpdateChan chan struct{}
// Relay tracking fields // Relay tracking fields
relayedPeers map[int]bool // siteID -> whether the peer is currently relayed relayedPeers map[int]bool // siteID -> whether the peer is currently relayed
@@ -83,6 +89,12 @@ type PeerMonitor struct {
} }
// NewPeerMonitor creates a new peer monitor with the given callback // NewPeerMonitor creates a new peer monitor with the given callback
func generateChainId() string {
b := make([]byte, 8)
_, _ = rand.Read(b)
return hex.EncodeToString(b)
}
func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API, publicDNS []string) *PeerMonitor { func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API, publicDNS []string) *PeerMonitor {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
pm := &PeerMonitor{ pm := &PeerMonitor{
@@ -101,6 +113,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
holepunchEndpoints: make(map[int]string), holepunchEndpoints: make(map[int]string),
holepunchStatus: make(map[int]bool), holepunchStatus: make(map[int]bool),
relayedPeers: make(map[int]bool), relayedPeers: make(map[int]bool),
relaySends: make(map[string]func()),
holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures
holepunchFailures: make(map[int]int), holepunchFailures: make(map[int]int),
// Rapid initial test settings: complete within ~1.5 seconds // Rapid initial test settings: complete within ~1.5 seconds
@@ -398,20 +411,23 @@ func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status Connectio
} }
} }
// sendRelay sends a relay message to the server // sendRelay sends a relay message to the server with retry, keyed by chainId
func (pm *PeerMonitor) sendRelay(siteID int) error { func (pm *PeerMonitor) sendRelay(siteID int) error {
if pm.wsClient == nil { if pm.wsClient == nil {
return fmt.Errorf("websocket client is nil") return fmt.Errorf("websocket client is nil")
} }
err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{ chainId := generateChainId()
"siteId": siteID, stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/relay", map[string]interface{}{
}) "siteId": siteID,
if err != nil { "chainId": chainId,
logger.Error("Failed to send registration message: %v", err) }, 2*time.Second, 10)
return err
} pm.relaySendMu.Lock()
logger.Info("Sent relay message") pm.relaySends[chainId] = stopFunc
pm.relaySendMu.Unlock()
logger.Info("Sent relay message for site %d (chain %s)", siteID, chainId)
return nil return nil
} }
@@ -421,23 +437,52 @@ func (pm *PeerMonitor) RequestRelay(siteID int) error {
return pm.sendRelay(siteID) return pm.sendRelay(siteID)
} }
// sendUnRelay sends an unrelay message to the server // sendUnRelay sends an unrelay message to the server with retry, keyed by chainId
func (pm *PeerMonitor) sendUnRelay(siteID int) error { func (pm *PeerMonitor) sendUnRelay(siteID int) error {
if pm.wsClient == nil { if pm.wsClient == nil {
return fmt.Errorf("websocket client is nil") return fmt.Errorf("websocket client is nil")
} }
err := pm.wsClient.SendMessage("olm/wg/unrelay", map[string]interface{}{ chainId := generateChainId()
"siteId": siteID, stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/unrelay", map[string]interface{}{
}) "siteId": siteID,
if err != nil { "chainId": chainId,
logger.Error("Failed to send registration message: %v", err) }, 2*time.Second, 10)
return err
} pm.relaySendMu.Lock()
logger.Info("Sent unrelay message") pm.relaySends[chainId] = stopFunc
pm.relaySendMu.Unlock()
logger.Info("Sent unrelay message for site %d (chain %s)", siteID, chainId)
return nil return nil
} }
// CancelRelaySend stops the interval sender for the given chainId, if one exists.
// If chainId is empty, all active relay senders are stopped.
func (pm *PeerMonitor) CancelRelaySend(chainId string) {
pm.relaySendMu.Lock()
defer pm.relaySendMu.Unlock()
if chainId == "" {
for id, stop := range pm.relaySends {
if stop != nil {
stop()
}
delete(pm.relaySends, id)
}
logger.Info("Cancelled all relay senders")
return
}
if stop, ok := pm.relaySends[chainId]; ok {
stop()
delete(pm.relaySends, chainId)
logger.Info("Cancelled relay sender for chain %s", chainId)
} else {
logger.Warn("CancelRelaySend: no active sender for chain %s", chainId)
}
}
// Stop stops monitoring all peers // Stop stops monitoring all peers
func (pm *PeerMonitor) Stop() { func (pm *PeerMonitor) Stop() {
// Stop holepunch monitor first (outside of mutex to avoid deadlock) // Stop holepunch monitor first (outside of mutex to avoid deadlock)
@@ -536,7 +581,7 @@ func (pm *PeerMonitor) runHolepunchMonitor() {
pm.holepunchCurrentInterval = pm.holepunchMinInterval pm.holepunchCurrentInterval = pm.holepunchMinInterval
currentInterval := pm.holepunchCurrentInterval currentInterval := pm.holepunchCurrentInterval
pm.mutex.Unlock() pm.mutex.Unlock()
timer.Reset(currentInterval) timer.Reset(currentInterval)
logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval) logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval)
case <-timer.C: case <-timer.C:
@@ -679,6 +724,16 @@ func (pm *PeerMonitor) Close() {
// Stop holepunch monitor first (outside of mutex to avoid deadlock) // Stop holepunch monitor first (outside of mutex to avoid deadlock)
pm.stopHolepunchMonitor() pm.stopHolepunchMonitor()
// Stop all pending relay senders
pm.relaySendMu.Lock()
for chainId, stop := range pm.relaySends {
if stop != nil {
stop()
}
delete(pm.relaySends, chainId)
}
pm.relaySendMu.Unlock()
pm.mutex.Lock() pm.mutex.Lock()
defer pm.mutex.Unlock() defer pm.mutex.Unlock()