Compare commits

..

9 Commits
dev ... jit

Author SHA1 Message Date
Owen
c2b5ef96a4 Jit of aliases working 2026-03-12 17:26:46 -07:00
Owen
e326da3d3e Merge branch 'dev' into jit 2026-03-12 16:53:16 -07:00
Owen
22cd02ae15 Alias jit handler 2026-03-11 15:56:51 -07:00
Owen
e2690bcc03 Store site id 2026-03-06 16:19:00 -08:00
Owen
f2d0e6a14c Merge branch 'dev' into jit 2026-03-06 16:08:24 -08:00
Owen
809dbe77de Make chainId in relay message bckwd compat 2026-03-06 15:27:03 -08:00
Owen
c67c2a60a1 Handle canceling sends for relay 2026-03-06 15:15:31 -08:00
Owen
051c0fdfd8 Working jit with chain ids 2026-03-04 17:51:48 -08:00
Owen
e7507e0837 Add api endpoints to jit 2026-03-04 17:01:17 -08:00
11 changed files with 476 additions and 109 deletions

View File

@@ -78,6 +78,13 @@ type MetadataChangeRequest struct {
Postures map[string]any `json:"postures"` Postures map[string]any `json:"postures"`
} }
// JITConnectionRequest defines the structure for a dynamic Just-In-Time connection request.
// Either SiteID or ResourceID must be provided (but not necessarily both).
type JITConnectionRequest struct {
Site string `json:"site,omitempty"`
Resource string `json:"resource,omitempty"`
}
// API represents the HTTP server and its state // API represents the HTTP server and its state
type API struct { type API struct {
addr string addr string
@@ -92,6 +99,7 @@ type API struct {
onExit func() error onExit func() error
onRebind func() error onRebind func() error
onPowerMode func(PowerModeRequest) error onPowerMode func(PowerModeRequest) error
onJITConnect func(JITConnectionRequest) error
statusMu sync.RWMutex statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus peerStatuses map[int]*PeerStatus
@@ -143,6 +151,7 @@ func (s *API) SetHandlers(
onExit func() error, onExit func() error,
onRebind func() error, onRebind func() error,
onPowerMode func(PowerModeRequest) error, onPowerMode func(PowerModeRequest) error,
onJITConnect func(JITConnectionRequest) error,
) { ) {
s.onConnect = onConnect s.onConnect = onConnect
s.onSwitchOrg = onSwitchOrg s.onSwitchOrg = onSwitchOrg
@@ -151,6 +160,7 @@ func (s *API) SetHandlers(
s.onExit = onExit s.onExit = onExit
s.onRebind = onRebind s.onRebind = onRebind
s.onPowerMode = onPowerMode s.onPowerMode = onPowerMode
s.onJITConnect = onJITConnect
} }
// Start starts the HTTP server // Start starts the HTTP server
@@ -169,6 +179,7 @@ func (s *API) Start() error {
mux.HandleFunc("/health", s.handleHealth) mux.HandleFunc("/health", s.handleHealth)
mux.HandleFunc("/rebind", s.handleRebind) mux.HandleFunc("/rebind", s.handleRebind)
mux.HandleFunc("/power-mode", s.handlePowerMode) mux.HandleFunc("/power-mode", s.handlePowerMode)
mux.HandleFunc("/jit-connect", s.handleJITConnect)
s.server = &http.Server{ s.server = &http.Server{
Handler: mux, Handler: mux,
@@ -633,6 +644,54 @@ func (s *API) handleRebind(w http.ResponseWriter, r *http.Request) {
}) })
} }
// handleJITConnect handles the /jit-connect endpoint.
// It initiates a dynamic Just-In-Time connection to a site identified by either
// a site or a resource. Exactly one of the two must be provided.
func (s *API) handleJITConnect(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req JITConnectionRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
return
}
// Validate that exactly one of site or resource is provided
if req.Site == "" && req.Resource == "" {
http.Error(w, "Missing required field: either site or resource must be provided", http.StatusBadRequest)
return
}
if req.Site != "" && req.Resource != "" {
http.Error(w, "Ambiguous request: provide either site or resource, not both", http.StatusBadRequest)
return
}
if req.Site != "" {
logger.Info("Received JIT connection request via API: site=%s", req.Site)
} else {
logger.Info("Received JIT connection request via API: resource=%s", req.Resource)
}
if s.onJITConnect != nil {
if err := s.onJITConnect(req); err != nil {
http.Error(w, fmt.Sprintf("JIT connection failed: %v", err), http.StatusInternalServerError)
return
}
} else {
http.Error(w, "JIT connect handler not configured", http.StatusNotImplemented)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusAccepted)
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "JIT connection request accepted",
})
}
// handlePowerMode handles the /power-mode endpoint // handlePowerMode handles the /power-mode endpoint
// This allows changing the power mode between "normal" and "low" // This allows changing the power mode between "normal" and "low"
func (s *API) handlePowerMode(w http.ResponseWriter, r *http.Request) { func (s *API) handlePowerMode(w http.ResponseWriter, r *http.Request) {

View File

@@ -45,6 +45,11 @@ type DNSProxy struct {
tunnelActivePorts map[uint16]bool tunnelActivePorts map[uint16]bool
tunnelPortsLock sync.Mutex tunnelPortsLock sync.Mutex
// jitHandler is called when a local record is resolved for a site that may not be
// connected yet, giving the caller a chance to initiate a JIT connection.
// It is invoked asynchronously so it never blocks DNS resolution.
jitHandler func(siteId int)
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
wg sync.WaitGroup wg sync.WaitGroup
@@ -384,6 +389,16 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie
response = p.checkLocalRecords(msg, question) response = p.checkLocalRecords(msg, question)
} }
// If a local A/AAAA record was found, notify the JIT handler so that the owning
// site can be connected on-demand if it is not yet active.
if response != nil && p.jitHandler != nil &&
(question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA) {
if siteId, ok := p.recordStore.GetSiteIdForDomain(question.Name); ok && siteId != 0 {
handler := p.jitHandler
go handler(siteId)
}
}
// If no local records, forward to upstream // If no local records, forward to upstream
if response == nil { if response == nil {
logger.Debug("No local record for %s, forwarding upstream to %v", question.Name, p.upstreamDNS) logger.Debug("No local record for %s, forwarding upstream to %v", question.Name, p.upstreamDNS)
@@ -718,11 +733,20 @@ func (p *DNSProxy) runPacketSender() {
} }
} }
// SetJITHandler registers a callback that is invoked whenever a local DNS record is
// resolved for an A or AAAA query. The siteId identifies which site owns the record.
// The handler is called in its own goroutine so it must be safe to call concurrently.
// Pass nil to disable JIT notifications.
func (p *DNSProxy) SetJITHandler(handler func(siteId int)) {
p.jitHandler = handler
}
// AddDNSRecord adds a DNS record to the local store // AddDNSRecord adds a DNS record to the local store
// domain should be a domain name (e.g., "example.com" or "example.com.") // domain should be a domain name (e.g., "example.com" or "example.com.")
// ip should be a valid IPv4 or IPv6 address // ip should be a valid IPv4 or IPv6 address
func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP) error { func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP, siteId int) error {
return p.recordStore.AddRecord(domain, ip) logger.Debug("Adding dns record for domain %s with IP %s (siteId=%d)", domain, ip.String(), siteId)
return p.recordStore.AddRecord(domain, ip, siteId)
} }
// RemoveDNSRecord removes a DNS record from the local store // RemoveDNSRecord removes a DNS record from the local store

View File

@@ -14,7 +14,7 @@ func TestCheckLocalRecordsNODATAForAAAA(t *testing.T) {
// Add an A record for a domain (no AAAA record) // Add an A record for a domain (no AAAA record)
ip := net.ParseIP("10.0.0.1") ip := net.ParseIP("10.0.0.1")
err := proxy.recordStore.AddRecord("myservice.internal", ip) err := proxy.recordStore.AddRecord("myservice.internal", ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add A record: %v", err) t.Fatalf("Failed to add A record: %v", err)
} }
@@ -64,7 +64,7 @@ func TestCheckLocalRecordsNODATAForA(t *testing.T) {
// Add an AAAA record for a domain (no A record) // Add an AAAA record for a domain (no A record)
ip := net.ParseIP("2001:db8::1") ip := net.ParseIP("2001:db8::1")
err := proxy.recordStore.AddRecord("ipv6only.internal", ip) err := proxy.recordStore.AddRecord("ipv6only.internal", ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add AAAA record: %v", err) t.Fatalf("Failed to add AAAA record: %v", err)
} }
@@ -113,7 +113,7 @@ func TestCheckLocalRecordsNonExistentDomain(t *testing.T) {
} }
// Add a record so the store isn't empty // Add a record so the store isn't empty
err := proxy.recordStore.AddRecord("exists.internal", net.ParseIP("10.0.0.1")) err := proxy.recordStore.AddRecord("exists.internal", net.ParseIP("10.0.0.1"), 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add record: %v", err) t.Fatalf("Failed to add record: %v", err)
} }
@@ -144,7 +144,7 @@ func TestCheckLocalRecordsNODATAWildcard(t *testing.T) {
// Add a wildcard A record (no AAAA) // Add a wildcard A record (no AAAA)
ip := net.ParseIP("10.0.0.1") ip := net.ParseIP("10.0.0.1")
err := proxy.recordStore.AddRecord("*.wildcard.internal", ip) err := proxy.recordStore.AddRecord("*.wildcard.internal", ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard A record: %v", err) t.Fatalf("Failed to add wildcard A record: %v", err)
} }

View File

@@ -22,6 +22,7 @@ const (
type recordSet struct { type recordSet struct {
A []net.IP A []net.IP
AAAA []net.IP AAAA []net.IP
SiteId int
} }
// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries. // DNSRecordStore manages local DNS records for A, AAAA, and PTR queries.
@@ -46,8 +47,9 @@ func NewDNSRecordStore() *DNSRecordStore {
// domain should be in FQDN format (e.g., "example.com.") // domain should be in FQDN format (e.g., "example.com.")
// domain can contain wildcards: * (0+ chars) and ? (exactly 1 char) // domain can contain wildcards: * (0+ chars) and ? (exactly 1 char)
// ip should be a valid IPv4 or IPv6 address // ip should be a valid IPv4 or IPv6 address
// siteId is the site that owns this alias/domain
// Automatically adds a corresponding PTR record for non-wildcard domains // Automatically adds a corresponding PTR record for non-wildcard domains
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error { func (s *DNSRecordStore) AddRecord(domain string, ip net.IP, siteId int) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@@ -69,12 +71,22 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
} }
if m[domain] == nil { if m[domain] == nil {
m[domain] = &recordSet{} m[domain] = &recordSet{SiteId: siteId}
} }
rs := m[domain] rs := m[domain]
if isV4 { if isV4 {
for _, existing := range rs.A {
if existing.Equal(ip) {
return nil
}
}
rs.A = append(rs.A, ip) rs.A = append(rs.A, ip)
} else { } else {
for _, existing := range rs.AAAA {
if existing.Equal(ip) {
return nil
}
}
rs.AAAA = append(rs.AAAA, ip) rs.AAAA = append(rs.AAAA, ip)
} }
@@ -85,6 +97,7 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
return nil return nil
} }
// AddPTRRecord adds a PTR record mapping an IP address to a domain name // AddPTRRecord adds a PTR record mapping an IP address to a domain name
// ip should be a valid IPv4 or IPv6 address // ip should be a valid IPv4 or IPv6 address
// domain should be in FQDN format (e.g., "example.com.") // domain should be in FQDN format (e.g., "example.com.")
@@ -179,6 +192,30 @@ func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) {
delete(s.ptrRecords, ip.String()) delete(s.ptrRecords, ip.String())
} }
// GetSiteIdForDomain returns the siteId associated with the given domain.
// It checks exact matches first, then wildcard patterns.
// The second return value is false if the domain is not found in local records.
func (s *DNSRecordStore) GetSiteIdForDomain(domain string) (int, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
domain = strings.ToLower(dns.Fqdn(domain))
// Check exact match first
if rs, exists := s.exact[domain]; exists {
return rs.SiteId, true
}
// Check wildcard matches
for pattern, rs := range s.wildcards {
if matchWildcard(pattern, domain) {
return rs.SiteId, true
}
}
return 0, false
}
// GetRecords returns all IP addresses for a domain and record type. // GetRecords returns all IP addresses for a domain and record type.
// The second return value indicates whether the domain exists at all // The second return value indicates whether the domain exists at all
// (true = domain exists, use NODATA if no records; false = NXDOMAIN). // (true = domain exists, use NODATA if no records; false = NXDOMAIN).

View File

@@ -170,14 +170,14 @@ func TestDNSRecordStoreWildcard(t *testing.T) {
// Add wildcard records // Add wildcard records
wildcardIP := net.ParseIP("10.0.0.1") wildcardIP := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", wildcardIP) err := store.AddRecord("*.autoco.internal", wildcardIP, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err) t.Fatalf("Failed to add wildcard record: %v", err)
} }
// Add exact record // Add exact record
exactIP := net.ParseIP("10.0.0.2") exactIP := net.ParseIP("10.0.0.2")
err = store.AddRecord("exact.autoco.internal", exactIP) err = store.AddRecord("exact.autoco.internal", exactIP, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add exact record: %v", err) t.Fatalf("Failed to add exact record: %v", err)
} }
@@ -221,7 +221,7 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) {
// Add complex wildcard pattern // Add complex wildcard pattern
ip1 := net.ParseIP("10.0.0.1") ip1 := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.host-0?.autoco.internal", ip1) err := store.AddRecord("*.host-0?.autoco.internal", ip1, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err) t.Fatalf("Failed to add wildcard record: %v", err)
} }
@@ -262,7 +262,7 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
// Add wildcard record // Add wildcard record
ip := net.ParseIP("10.0.0.1") ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", ip) err := store.AddRecord("*.autoco.internal", ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err) t.Fatalf("Failed to add wildcard record: %v", err)
} }
@@ -297,18 +297,18 @@ func TestDNSRecordStoreMultipleWildcards(t *testing.T) {
ip2 := net.ParseIP("10.0.0.2") ip2 := net.ParseIP("10.0.0.2")
ip3 := net.ParseIP("10.0.0.3") ip3 := net.ParseIP("10.0.0.3")
err := store.AddRecord("*.prod.autoco.internal", ip1) err := store.AddRecord("*.prod.autoco.internal", ip1, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add first wildcard: %v", err) t.Fatalf("Failed to add first wildcard: %v", err)
} }
err = store.AddRecord("*.dev.autoco.internal", ip2) err = store.AddRecord("*.dev.autoco.internal", ip2, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add second wildcard: %v", err) t.Fatalf("Failed to add second wildcard: %v", err)
} }
// Add a broader wildcard that matches both // Add a broader wildcard that matches both
err = store.AddRecord("*.autoco.internal", ip3) err = store.AddRecord("*.autoco.internal", ip3, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add third wildcard: %v", err) t.Fatalf("Failed to add third wildcard: %v", err)
} }
@@ -337,7 +337,7 @@ func TestDNSRecordStoreIPv6Wildcard(t *testing.T) {
// Add IPv6 wildcard record // Add IPv6 wildcard record
ip := net.ParseIP("2001:db8::1") ip := net.ParseIP("2001:db8::1")
err := store.AddRecord("*.autoco.internal", ip) err := store.AddRecord("*.autoco.internal", ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add IPv6 wildcard record: %v", err) t.Fatalf("Failed to add IPv6 wildcard record: %v", err)
} }
@@ -357,7 +357,7 @@ func TestHasRecordWildcard(t *testing.T) {
// Add wildcard record // Add wildcard record
ip := net.ParseIP("10.0.0.1") ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", ip) err := store.AddRecord("*.autoco.internal", ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err) t.Fatalf("Failed to add wildcard record: %v", err)
} }
@@ -378,7 +378,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
// Add record with mixed case // Add record with mixed case
ip := net.ParseIP("10.0.0.1") ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("MyHost.AutoCo.Internal", ip) err := store.AddRecord("MyHost.AutoCo.Internal", ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add mixed case record: %v", err) t.Fatalf("Failed to add mixed case record: %v", err)
} }
@@ -403,7 +403,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
// Test wildcard with mixed case // Test wildcard with mixed case
wildcardIP := net.ParseIP("10.0.0.2") wildcardIP := net.ParseIP("10.0.0.2")
err = store.AddRecord("*.Example.Com", wildcardIP) err = store.AddRecord("*.Example.Com", wildcardIP, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add mixed case wildcard: %v", err) t.Fatalf("Failed to add mixed case wildcard: %v", err)
} }
@@ -689,7 +689,7 @@ func TestClearPTRRecords(t *testing.T) {
store.AddPTRRecord(ip2, "host2.example.com.") store.AddPTRRecord(ip2, "host2.example.com.")
// Add some A records too // Add some A records too
store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1")) store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1"), 0)
// Verify PTR records exist // Verify PTR records exist
if !store.HasPTRRecord("1.1.168.192.in-addr.arpa.") { if !store.HasPTRRecord("1.1.168.192.in-addr.arpa.") {
@@ -719,7 +719,7 @@ func TestAutomaticPTRRecordOnAdd(t *testing.T) {
// Add an A record - should automatically add PTR record // Add an A record - should automatically add PTR record
domain := "host.example.com." domain := "host.example.com."
ip := net.ParseIP("192.168.1.100") ip := net.ParseIP("192.168.1.100")
err := store.AddRecord(domain, ip) err := store.AddRecord(domain, ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add A record: %v", err) t.Fatalf("Failed to add A record: %v", err)
} }
@@ -737,7 +737,7 @@ func TestAutomaticPTRRecordOnAdd(t *testing.T) {
// Add AAAA record - should also automatically add PTR record // Add AAAA record - should also automatically add PTR record
domain6 := "ipv6host.example.com." domain6 := "ipv6host.example.com."
ip6 := net.ParseIP("2001:db8::1") ip6 := net.ParseIP("2001:db8::1")
err = store.AddRecord(domain6, ip6) err = store.AddRecord(domain6, ip6, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add AAAA record: %v", err) t.Fatalf("Failed to add AAAA record: %v", err)
} }
@@ -759,7 +759,7 @@ func TestAutomaticPTRRecordOnRemove(t *testing.T) {
// Add an A record (with automatic PTR) // Add an A record (with automatic PTR)
domain := "host.example.com." domain := "host.example.com."
ip := net.ParseIP("192.168.1.100") ip := net.ParseIP("192.168.1.100")
store.AddRecord(domain, ip) store.AddRecord(domain, ip, 0)
// Verify PTR exists // Verify PTR exists
reverseDomain := "100.1.168.192.in-addr.arpa." reverseDomain := "100.1.168.192.in-addr.arpa."
@@ -789,8 +789,8 @@ func TestAutomaticPTRRecordOnRemoveAll(t *testing.T) {
domain := "host.example.com." domain := "host.example.com."
ip1 := net.ParseIP("192.168.1.100") ip1 := net.ParseIP("192.168.1.100")
ip2 := net.ParseIP("192.168.1.101") ip2 := net.ParseIP("192.168.1.101")
store.AddRecord(domain, ip1) store.AddRecord(domain, ip1, 0)
store.AddRecord(domain, ip2) store.AddRecord(domain, ip2, 0)
// Verify both PTR records exist // Verify both PTR records exist
reverseDomain1 := "100.1.168.192.in-addr.arpa." reverseDomain1 := "100.1.168.192.in-addr.arpa."
@@ -820,7 +820,7 @@ func TestNoPTRForWildcardRecords(t *testing.T) {
// Add wildcard record - should NOT create PTR record // Add wildcard record - should NOT create PTR record
domain := "*.example.com." domain := "*.example.com."
ip := net.ParseIP("192.168.1.100") ip := net.ParseIP("192.168.1.100")
err := store.AddRecord(domain, ip) err := store.AddRecord(domain, ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err) t.Fatalf("Failed to add wildcard record: %v", err)
} }
@@ -844,7 +844,7 @@ func TestPTRRecordOverwrite(t *testing.T) {
// Add first domain with IP // Add first domain with IP
domain1 := "host1.example.com." domain1 := "host1.example.com."
ip := net.ParseIP("192.168.1.100") ip := net.ParseIP("192.168.1.100")
store.AddRecord(domain1, ip) store.AddRecord(domain1, ip, 0)
// Verify PTR points to first domain // Verify PTR points to first domain
reverseDomain := "100.1.168.192.in-addr.arpa." reverseDomain := "100.1.168.192.in-addr.arpa."
@@ -858,7 +858,7 @@ func TestPTRRecordOverwrite(t *testing.T) {
// Add second domain with same IP - should overwrite PTR // Add second domain with same IP - should overwrite PTR
domain2 := "host2.example.com." domain2 := "host2.example.com."
store.AddRecord(domain2, ip) store.AddRecord(domain2, ip, 0)
// Verify PTR now points to second domain (last one added) // Verify PTR now points to second domain (last one added)
result, ok = store.GetPTRRecord(reverseDomain) result, ok = store.GetPTRRecord(reverseDomain)

View File

@@ -7,6 +7,7 @@ import (
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"github.com/fosrl/newt/network" "github.com/fosrl/newt/network"
@@ -174,11 +175,7 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
for i := range wgData.Sites { for i := range wgData.Sites {
site := wgData.Sites[i] site := wgData.Sites[i]
if site.PublicKey == "" { if site.PublicKey != "" {
logger.Warn("Skipping site %d (%s): no public key available (site may not be connected)", site.SiteId, site.Name)
continue
}
var siteEndpoint string var siteEndpoint string
// here we are going to take the relay endpoint if it exists which means we requested a relay for this peer // here we are going to take the relay endpoint if it exists which means we requested a relay for this peer
if site.RelayEndpoint != "" { if site.RelayEndpoint != "" {
@@ -188,7 +185,9 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
} }
o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false) o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false)
}
// we still call this to add the aliases for jit lookup but we just do that then pass inside. need to skip the above so we dont add to the api
if err := o.peerManager.AddPeer(site); err != nil { if err := o.peerManager.AddPeer(site); err != nil {
logger.Error("Failed to add peer: %v", err) logger.Error("Failed to add peer: %v", err)
return return
@@ -203,6 +202,36 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
logger.Error("Failed to start DNS proxy: %v", err) logger.Error("Failed to start DNS proxy: %v", err)
} }
// Register JIT handler: when the DNS proxy resolves a local record, check whether
// the owning site is already connected and, if not, initiate a JIT connection.
o.dnsProxy.SetJITHandler(func(siteId int) {
if o.peerManager == nil || o.websocket == nil {
return
}
// Site already has an active peer connection - nothing to do.
if _, exists := o.peerManager.GetPeer(siteId); exists {
return
}
o.peerSendMu.Lock()
defer o.peerSendMu.Unlock()
// A JIT request for this site is already in-flight - avoid duplicate sends.
if _, pending := o.jitPendingSites[siteId]; pending {
return
}
chainId := generateChainId()
logger.Info("DNS-triggered JIT connect for site %d (chainId=%s)", siteId, chainId)
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{
"siteId": siteId,
"chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerInits[chainId] = stopFunc
o.jitPendingSites[siteId] = chainId
})
if o.tunnelConfig.OverrideDNS { if o.tunnelConfig.OverrideDNS {
// Set up DNS override to use our DNS proxy // Set up DNS override to use our DNS proxy
if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil { if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil {

View File

@@ -2,6 +2,7 @@ package olm
import ( import (
"encoding/json" "encoding/json"
"fmt"
"time" "time"
"github.com/fosrl/newt/holepunch" "github.com/fosrl/newt/holepunch"
@@ -220,6 +221,7 @@ func (o *Olm) handleSync(msg websocket.WSMessage) {
logger.Info("Sync: Adding new peer for site %d", siteId) logger.Info("Sync: Adding new peer for site %d", siteId)
o.holePunchManager.TriggerHolePunch() o.holePunchManager.TriggerHolePunch()
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
// // TODO: do we need to send the message to the cloud to add the peer that way? // // TODO: do we need to send the message to the cloud to add the peer that way?
// if err := o.peerManager.AddPeer(expectedSite); err != nil { // if err := o.peerManager.AddPeer(expectedSite); err != nil {
@@ -230,9 +232,17 @@ func (o *Olm) handleSync(msg websocket.WSMessage) {
// add the peer via the server // add the peer via the server
// this is important because newt needs to get triggered as well to add the peer once the hp is complete // this is important because newt needs to get triggered as well to add the peer once the hp is complete
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ chainId := fmt.Sprintf("sync-%d", expectedSite.SiteId)
o.peerSendMu.Lock()
if stop, ok := o.stopPeerSends[chainId]; ok {
stop()
}
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": expectedSite.SiteId, "siteId": expectedSite.SiteId,
}, 1*time.Second, 10) "chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerSends[chainId] = stopFunc
o.peerSendMu.Unlock()
} else { } else {
// Existing peer - check if update is needed // Existing peer - check if update is needed

View File

@@ -2,6 +2,8 @@ package olm
import ( import (
"context" "context"
"crypto/rand"
"encoding/hex"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@@ -65,7 +67,10 @@ type Olm struct {
stopRegister func() stopRegister func()
updateRegister func(newData any) updateRegister func(newData any)
stopPeerSend func() stopPeerSends map[string]func()
stopPeerInits map[string]func()
jitPendingSites map[int]string // siteId -> chainId for in-flight JIT requests
peerSendMu sync.Mutex
// WaitGroup to track tunnel lifecycle // WaitGroup to track tunnel lifecycle
tunnelWg sync.WaitGroup tunnelWg sync.WaitGroup
@@ -116,6 +121,13 @@ func (o *Olm) initTunnelInfo(clientID string) error {
return nil return nil
} }
// generateChainId generates a random chain ID for tracking peer sender lifecycles.
func generateChainId() string {
b := make([]byte, 8)
_, _ = rand.Read(b)
return hex.EncodeToString(b)
}
func Init(ctx context.Context, config OlmConfig) (*Olm, error) { func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel)) logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
@@ -170,6 +182,9 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
olmCtx: ctx, olmCtx: ctx,
apiServer: apiServer, apiServer: apiServer,
olmConfig: config, olmConfig: config,
stopPeerSends: make(map[string]func()),
stopPeerInits: make(map[string]func()),
jitPendingSites: make(map[int]string),
} }
newOlm.registerAPICallbacks() newOlm.registerAPICallbacks()
@@ -284,6 +299,21 @@ func (o *Olm) registerAPICallbacks() {
logger.Info("Processing power mode change request via API: mode=%s", req.Mode) logger.Info("Processing power mode change request via API: mode=%s", req.Mode)
return o.SetPowerMode(req.Mode) return o.SetPowerMode(req.Mode)
}, },
func(req api.JITConnectionRequest) error {
logger.Info("Processing JIT connect request via API: site=%s resource=%s", req.Site, req.Resource)
chainId := generateChainId()
o.peerSendMu.Lock()
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{
"siteId": req.Site,
"resourceId": req.Resource,
"chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerInits[chainId] = stopFunc
o.peerSendMu.Unlock()
return nil
},
) )
} }
@@ -385,6 +415,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
// Handler for peer handshake - adds exit node to holepunch rotation and notifies server // Handler for peer handshake - adds exit node to holepunch rotation and notifies server
o.websocket.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite) o.websocket.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite)
o.websocket.RegisterHandler("olm/wg/peer/chain/cancel", o.handleCancelChain)
o.websocket.RegisterHandler("olm/sync", o.handleSync) o.websocket.RegisterHandler("olm/sync", o.handleSync)
o.websocket.OnConnect(func() error { o.websocket.OnConnect(func() error {
@@ -427,7 +458,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
"userToken": userToken, "userToken": userToken,
"fingerprint": o.fingerprint, "fingerprint": o.fingerprint,
"postures": o.postures, "postures": o.postures,
}, 1*time.Second, 10) }, 2*time.Second, 10)
// Invoke onRegistered callback if configured // Invoke onRegistered callback if configured
if o.olmConfig.OnRegistered != nil { if o.olmConfig.OnRegistered != nil {
@@ -524,6 +555,23 @@ func (o *Olm) Close() {
o.stopRegister = nil o.stopRegister = nil
} }
// Stop all pending peer init and send senders before closing websocket
o.peerSendMu.Lock()
for _, stop := range o.stopPeerInits {
if stop != nil {
stop()
}
}
o.stopPeerInits = make(map[string]func())
for _, stop := range o.stopPeerSends {
if stop != nil {
stop()
}
}
o.stopPeerSends = make(map[string]func())
o.jitPendingSites = make(map[int]string)
o.peerSendMu.Unlock()
// send a disconnect message to the cloud to show disconnected // send a disconnect message to the cloud to show disconnected
if o.websocket != nil { if o.websocket != nil {
o.websocket.SendMessage("olm/disconnecting", map[string]any{}) o.websocket.SendMessage("olm/disconnecting", map[string]any{})

View File

@@ -20,36 +20,43 @@ func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
return return
} }
if o.stopPeerSend != nil {
o.stopPeerSend()
o.stopPeerSend = nil
}
jsonData, err := json.Marshal(msg.Data) jsonData, err := json.Marshal(msg.Data)
if err != nil { if err != nil {
logger.Error("Error marshaling data: %v", err) logger.Error("Error marshaling data: %v", err)
return return
} }
var siteConfig peers.SiteConfig var siteConfigMsg struct {
if err := json.Unmarshal(jsonData, &siteConfig); err != nil { peers.SiteConfig
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &siteConfigMsg); err != nil {
logger.Error("Error unmarshaling add data: %v", err) logger.Error("Error unmarshaling add data: %v", err)
return return
} }
if siteConfig.PublicKey == "" { if siteConfigMsg.ChainId != "" {
logger.Warn("Skipping add-peer for site %d (%s): no public key available (site may not be connected)", siteConfig.SiteId, siteConfig.Name) o.peerSendMu.Lock()
if stop, ok := o.stopPeerSends[siteConfigMsg.ChainId]; ok {
stop()
delete(o.stopPeerSends, siteConfigMsg.ChainId)
}
o.peerSendMu.Unlock()
}
if siteConfigMsg.PublicKey == "" {
logger.Warn("Skipping add-peer for site %d (%s): no public key available (site may not be connected)", siteConfigMsg.SiteId, siteConfigMsg.Name)
return return
} }
_ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it _ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
if err := o.peerManager.AddPeer(siteConfig); err != nil { if err := o.peerManager.AddPeer(siteConfigMsg.SiteConfig); err != nil {
logger.Error("Failed to add peer: %v", err) logger.Error("Failed to add peer: %v", err)
return return
} }
logger.Info("Successfully added peer for site %d", siteConfig.SiteId) logger.Info("Successfully added peer for site %d", siteConfigMsg.SiteId)
} }
func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) { func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
@@ -169,13 +176,21 @@ func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) {
return return
} }
var relayData peers.RelayPeerData var relayData struct {
peers.RelayPeerData
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &relayData); err != nil { if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err) logger.Error("Error unmarshaling relay data: %v", err)
return return
} }
if monitor := o.peerManager.GetPeerMonitor(); monitor != nil {
monitor.CancelRelaySend(relayData.ChainId)
}
primaryRelay, err := util.ResolveDomainUpstream(relayData.RelayEndpoint, o.tunnelConfig.PublicDNS) primaryRelay, err := util.ResolveDomainUpstream(relayData.RelayEndpoint, o.tunnelConfig.PublicDNS)
if err != nil { if err != nil {
logger.Error("Failed to resolve primary relay endpoint: %v", err) logger.Error("Failed to resolve primary relay endpoint: %v", err)
return return
@@ -202,13 +217,21 @@ func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
return return
} }
var relayData peers.UnRelayPeerData var relayData struct {
peers.UnRelayPeerData
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &relayData); err != nil { if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err) logger.Error("Error unmarshaling relay data: %v", err)
return return
} }
if monitor := o.peerManager.GetPeerMonitor(); monitor != nil {
monitor.CancelRelaySend(relayData.ChainId)
}
primaryRelay, err := util.ResolveDomainUpstream(relayData.Endpoint, o.tunnelConfig.PublicDNS) primaryRelay, err := util.ResolveDomainUpstream(relayData.Endpoint, o.tunnelConfig.PublicDNS)
if err != nil { if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err) logger.Warn("Failed to resolve primary relay endpoint: %v", err)
} }
@@ -236,6 +259,7 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
var handshakeData struct { var handshakeData struct {
SiteId int `json:"siteId"` SiteId int `json:"siteId"`
ChainId string `json:"chainId"`
ExitNode struct { ExitNode struct {
PublicKey string `json:"publicKey"` PublicKey string `json:"publicKey"`
Endpoint string `json:"endpoint"` Endpoint string `json:"endpoint"`
@@ -248,6 +272,19 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
return return
} }
// Stop the peer init sender for this chain, if any
if handshakeData.ChainId != "" {
o.peerSendMu.Lock()
if stop, ok := o.stopPeerInits[handshakeData.ChainId]; ok {
stop()
delete(o.stopPeerInits, handshakeData.ChainId)
}
// If this chain was initiated by a DNS-triggered JIT request, clear the
// pending entry so the site can be re-triggered if needed in the future.
delete(o.jitPendingSites, handshakeData.SiteId)
o.peerSendMu.Unlock()
}
// Get existing peer from PeerManager // Get existing peer from PeerManager
_, exists := o.peerManager.GetPeer(handshakeData.SiteId) _, exists := o.peerManager.GetPeer(handshakeData.SiteId)
if exists { if exists {
@@ -278,10 +315,72 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
// Send handshake acknowledgment back to server with retry // Send handshake acknowledgment back to server with retry, keyed by chainId
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ chainId := handshakeData.ChainId
if chainId == "" {
chainId = generateChainId()
}
o.peerSendMu.Lock()
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": handshakeData.SiteId, "siteId": handshakeData.SiteId,
}, 1*time.Second, 10) "chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerSends[chainId] = stopFunc
o.peerSendMu.Unlock()
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
} }
func (o *Olm) handleCancelChain(msg websocket.WSMessage) {
logger.Debug("Received cancel-chain message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling cancel-chain data: %v", err)
return
}
var cancelData struct {
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &cancelData); err != nil {
logger.Error("Error unmarshaling cancel-chain data: %v", err)
return
}
if cancelData.ChainId == "" {
logger.Warn("Received cancel-chain message with no chainId")
return
}
o.peerSendMu.Lock()
defer o.peerSendMu.Unlock()
found := false
if stop, ok := o.stopPeerInits[cancelData.ChainId]; ok {
stop()
delete(o.stopPeerInits, cancelData.ChainId)
found = true
}
// If this chain was a DNS-triggered JIT request, clear the pending entry so
// the site can be re-triggered on the next DNS lookup.
for siteId, chainId := range o.jitPendingSites {
if chainId == cancelData.ChainId {
delete(o.jitPendingSites, siteId)
break
}
}
if stop, ok := o.stopPeerSends[cancelData.ChainId]; ok {
stop()
delete(o.stopPeerSends, cancelData.ChainId)
found = true
}
if found {
logger.Info("Cancelled chain %s", cancelData.ChainId)
} else {
logger.Warn("Cancel-chain: no active sender found for chain %s", cancelData.ChainId)
}
}

View File

@@ -111,6 +111,19 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
for _, alias := range siteConfig.Aliases {
address := net.ParseIP(alias.AliasAddress)
if address == nil {
continue
}
pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteConfig.SiteId)
}
if siteConfig.PublicKey == "" {
logger.Debug("Skip adding site %d because no pub key", siteConfig.SiteId)
return nil
}
// build the allowed IPs list from the remote subnets and aliases and add them to the peer // build the allowed IPs list from the remote subnets and aliases and add them to the peer
allowedIPs := make([]string, 0, len(siteConfig.RemoteSubnets)+len(siteConfig.Aliases)) allowedIPs := make([]string, 0, len(siteConfig.RemoteSubnets)+len(siteConfig.Aliases))
allowedIPs = append(allowedIPs, siteConfig.RemoteSubnets...) allowedIPs = append(allowedIPs, siteConfig.RemoteSubnets...)
@@ -143,13 +156,6 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
if err := network.AddRoutes(siteConfig.RemoteSubnets, pm.interfaceName); err != nil { if err := network.AddRoutes(siteConfig.RemoteSubnets, pm.interfaceName); err != nil {
logger.Error("Failed to add routes for remote subnets: %v", err) logger.Error("Failed to add routes for remote subnets: %v", err)
} }
for _, alias := range siteConfig.Aliases {
address := net.ParseIP(alias.AliasAddress)
if address == nil {
continue
}
pm.dnsProxy.AddDNSRecord(alias.Alias, address)
}
monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0] monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
@@ -437,7 +443,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
if address == nil { if address == nil {
continue continue
} }
pm.dnsProxy.AddDNSRecord(alias.Alias, address) pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteConfig.SiteId)
} }
pm.peerMonitor.UpdateHolepunchEndpoint(siteConfig.SiteId, siteConfig.Endpoint) pm.peerMonitor.UpdateHolepunchEndpoint(siteConfig.SiteId, siteConfig.Endpoint)
@@ -717,7 +723,7 @@ func (pm *PeerManager) AddAlias(siteId int, alias Alias) error {
address := net.ParseIP(alias.AliasAddress) address := net.ParseIP(alias.AliasAddress)
if address != nil { if address != nil {
pm.dnsProxy.AddDNSRecord(alias.Alias, address) pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteId)
} }
// Add an allowed IP for the alias // Add an allowed IP for the alias

View File

@@ -2,6 +2,8 @@ package monitor
import ( import (
"context" "context"
"crypto/rand"
"encoding/hex"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@@ -36,6 +38,10 @@ type PeerMonitor struct {
wsClient *websocket.Client wsClient *websocket.Client
publicDNS []string publicDNS []string
// Relay sender tracking
relaySends map[string]func()
relaySendMu sync.Mutex
// Netstack fields // Netstack fields
middleDev *middleDevice.MiddleDevice middleDev *middleDevice.MiddleDevice
localIP string localIP string
@@ -83,6 +89,12 @@ type PeerMonitor struct {
} }
// NewPeerMonitor creates a new peer monitor with the given callback // NewPeerMonitor creates a new peer monitor with the given callback
func generateChainId() string {
b := make([]byte, 8)
_, _ = rand.Read(b)
return hex.EncodeToString(b)
}
func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API, publicDNS []string) *PeerMonitor { func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API, publicDNS []string) *PeerMonitor {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
pm := &PeerMonitor{ pm := &PeerMonitor{
@@ -101,6 +113,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
holepunchEndpoints: make(map[int]string), holepunchEndpoints: make(map[int]string),
holepunchStatus: make(map[int]bool), holepunchStatus: make(map[int]bool),
relayedPeers: make(map[int]bool), relayedPeers: make(map[int]bool),
relaySends: make(map[string]func()),
holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures
holepunchFailures: make(map[int]int), holepunchFailures: make(map[int]int),
// Rapid initial test settings: complete within ~1.5 seconds // Rapid initial test settings: complete within ~1.5 seconds
@@ -398,20 +411,23 @@ func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status Connectio
} }
} }
// sendRelay sends a relay message to the server // sendRelay sends a relay message to the server with retry, keyed by chainId
func (pm *PeerMonitor) sendRelay(siteID int) error { func (pm *PeerMonitor) sendRelay(siteID int) error {
if pm.wsClient == nil { if pm.wsClient == nil {
return fmt.Errorf("websocket client is nil") return fmt.Errorf("websocket client is nil")
} }
err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{ chainId := generateChainId()
stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/relay", map[string]interface{}{
"siteId": siteID, "siteId": siteID,
}) "chainId": chainId,
if err != nil { }, 2*time.Second, 10)
logger.Error("Failed to send registration message: %v", err)
return err pm.relaySendMu.Lock()
} pm.relaySends[chainId] = stopFunc
logger.Info("Sent relay message") pm.relaySendMu.Unlock()
logger.Info("Sent relay message for site %d (chain %s)", siteID, chainId)
return nil return nil
} }
@@ -421,23 +437,52 @@ func (pm *PeerMonitor) RequestRelay(siteID int) error {
return pm.sendRelay(siteID) return pm.sendRelay(siteID)
} }
// sendUnRelay sends an unrelay message to the server // sendUnRelay sends an unrelay message to the server with retry, keyed by chainId
func (pm *PeerMonitor) sendUnRelay(siteID int) error { func (pm *PeerMonitor) sendUnRelay(siteID int) error {
if pm.wsClient == nil { if pm.wsClient == nil {
return fmt.Errorf("websocket client is nil") return fmt.Errorf("websocket client is nil")
} }
err := pm.wsClient.SendMessage("olm/wg/unrelay", map[string]interface{}{ chainId := generateChainId()
stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/unrelay", map[string]interface{}{
"siteId": siteID, "siteId": siteID,
}) "chainId": chainId,
if err != nil { }, 2*time.Second, 10)
logger.Error("Failed to send registration message: %v", err)
return err pm.relaySendMu.Lock()
} pm.relaySends[chainId] = stopFunc
logger.Info("Sent unrelay message") pm.relaySendMu.Unlock()
logger.Info("Sent unrelay message for site %d (chain %s)", siteID, chainId)
return nil return nil
} }
// CancelRelaySend stops the interval sender for the given chainId, if one exists.
// If chainId is empty, all active relay senders are stopped.
func (pm *PeerMonitor) CancelRelaySend(chainId string) {
pm.relaySendMu.Lock()
defer pm.relaySendMu.Unlock()
if chainId == "" {
for id, stop := range pm.relaySends {
if stop != nil {
stop()
}
delete(pm.relaySends, id)
}
logger.Info("Cancelled all relay senders")
return
}
if stop, ok := pm.relaySends[chainId]; ok {
stop()
delete(pm.relaySends, chainId)
logger.Info("Cancelled relay sender for chain %s", chainId)
} else {
logger.Warn("CancelRelaySend: no active sender for chain %s", chainId)
}
}
// Stop stops monitoring all peers // Stop stops monitoring all peers
func (pm *PeerMonitor) Stop() { func (pm *PeerMonitor) Stop() {
// Stop holepunch monitor first (outside of mutex to avoid deadlock) // Stop holepunch monitor first (outside of mutex to avoid deadlock)
@@ -679,6 +724,16 @@ func (pm *PeerMonitor) Close() {
// Stop holepunch monitor first (outside of mutex to avoid deadlock) // Stop holepunch monitor first (outside of mutex to avoid deadlock)
pm.stopHolepunchMonitor() pm.stopHolepunchMonitor()
// Stop all pending relay senders
pm.relaySendMu.Lock()
for chainId, stop := range pm.relaySends {
if stop != nil {
stop()
}
delete(pm.relaySends, chainId)
}
pm.relaySendMu.Unlock()
pm.mutex.Lock() pm.mutex.Lock()
defer pm.mutex.Unlock() defer pm.mutex.Unlock()