Compare commits

..

9 Commits
dev ... jit

Author SHA1 Message Date
Owen
c2b5ef96a4 Jit of aliases working 2026-03-12 17:26:46 -07:00
Owen
e326da3d3e Merge branch 'dev' into jit 2026-03-12 16:53:16 -07:00
Owen
22cd02ae15 Alias jit handler 2026-03-11 15:56:51 -07:00
Owen
e2690bcc03 Store site id 2026-03-06 16:19:00 -08:00
Owen
f2d0e6a14c Merge branch 'dev' into jit 2026-03-06 16:08:24 -08:00
Owen
809dbe77de Make chainId in relay message bckwd compat 2026-03-06 15:27:03 -08:00
Owen
c67c2a60a1 Handle canceling sends for relay 2026-03-06 15:15:31 -08:00
Owen
051c0fdfd8 Working jit with chain ids 2026-03-04 17:51:48 -08:00
Owen
e7507e0837 Add api endpoints to jit 2026-03-04 17:01:17 -08:00
11 changed files with 476 additions and 109 deletions

View File

@@ -78,6 +78,13 @@ type MetadataChangeRequest struct {
Postures map[string]any `json:"postures"`
}
// JITConnectionRequest is the JSON body of a dynamic Just-In-Time connection request.
// handleJITConnect requires exactly one of Site or Resource to be non-empty: it rejects
// a request where both are empty and a request where both are set.
type JITConnectionRequest struct {
// Site identifies the target by site (JSON key "site", omitted when empty).
Site string `json:"site,omitempty"`
// Resource identifies the target by resource (JSON key "resource", omitted when empty).
Resource string `json:"resource,omitempty"`
}
// API represents the HTTP server and its state
type API struct {
addr string
@@ -92,6 +99,7 @@ type API struct {
onExit func() error
onRebind func() error
onPowerMode func(PowerModeRequest) error
onJITConnect func(JITConnectionRequest) error
statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus
@@ -143,6 +151,7 @@ func (s *API) SetHandlers(
onExit func() error,
onRebind func() error,
onPowerMode func(PowerModeRequest) error,
onJITConnect func(JITConnectionRequest) error,
) {
s.onConnect = onConnect
s.onSwitchOrg = onSwitchOrg
@@ -151,6 +160,7 @@ func (s *API) SetHandlers(
s.onExit = onExit
s.onRebind = onRebind
s.onPowerMode = onPowerMode
s.onJITConnect = onJITConnect
}
// Start starts the HTTP server
@@ -169,6 +179,7 @@ func (s *API) Start() error {
mux.HandleFunc("/health", s.handleHealth)
mux.HandleFunc("/rebind", s.handleRebind)
mux.HandleFunc("/power-mode", s.handlePowerMode)
mux.HandleFunc("/jit-connect", s.handleJITConnect)
s.server = &http.Server{
Handler: mux,
@@ -633,6 +644,54 @@ func (s *API) handleRebind(w http.ResponseWriter, r *http.Request) {
})
}
// handleJITConnect serves the /jit-connect endpoint.
//
// It accepts only POST requests whose JSON body decodes into a
// JITConnectionRequest with exactly one of Site or Resource set, and hands the
// request to the registered onJITConnect callback.
//
// Responses:
//   - 405 when the method is not POST
//   - 400 when the body cannot be decoded, or when neither or both fields are set
//   - 501 when no onJITConnect callback has been registered
//   - 500 when the callback returns an error
//   - 202 with a small JSON status object when the callback succeeds
func (s *API) handleJITConnect(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req JITConnectionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
		return
	}

	// Exactly one selector must be present; log which one was used.
	hasSite, hasResource := req.Site != "", req.Resource != ""
	switch {
	case !hasSite && !hasResource:
		http.Error(w, "Missing required field: either site or resource must be provided", http.StatusBadRequest)
		return
	case hasSite && hasResource:
		http.Error(w, "Ambiguous request: provide either site or resource, not both", http.StatusBadRequest)
		return
	case hasSite:
		logger.Info("Received JIT connection request via API: site=%s", req.Site)
	default:
		logger.Info("Received JIT connection request via API: resource=%s", req.Resource)
	}

	if s.onJITConnect == nil {
		http.Error(w, "JIT connect handler not configured", http.StatusNotImplemented)
		return
	}
	if err := s.onJITConnect(req); err != nil {
		http.Error(w, fmt.Sprintf("JIT connection failed: %v", err), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusAccepted)
	// The status line is already written, so an encode failure cannot be reported to the client.
	_ = json.NewEncoder(w).Encode(map[string]string{
		"status": "JIT connection request accepted",
	})
}
// handlePowerMode handles the /power-mode endpoint
// This allows changing the power mode between "normal" and "low"
func (s *API) handlePowerMode(w http.ResponseWriter, r *http.Request) {

View File

@@ -45,6 +45,11 @@ type DNSProxy struct {
tunnelActivePorts map[uint16]bool
tunnelPortsLock sync.Mutex
// jitHandler is called when a local record is resolved for a site that may not be
// connected yet, giving the caller a chance to initiate a JIT connection.
// It is invoked asynchronously so it never blocks DNS resolution.
jitHandler func(siteId int)
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
@@ -384,6 +389,16 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie
response = p.checkLocalRecords(msg, question)
}
// If a local A/AAAA record was found, notify the JIT handler so that the owning
// site can be connected on-demand if it is not yet active.
if response != nil && p.jitHandler != nil &&
(question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA) {
if siteId, ok := p.recordStore.GetSiteIdForDomain(question.Name); ok && siteId != 0 {
handler := p.jitHandler
go handler(siteId)
}
}
// If no local records, forward to upstream
if response == nil {
logger.Debug("No local record for %s, forwarding upstream to %v", question.Name, p.upstreamDNS)
@@ -718,11 +733,20 @@ func (p *DNSProxy) runPacketSender() {
}
}
// SetJITHandler registers the callback used for Just-In-Time connection notifications.
// handleDNSQuery invokes it when a local response was produced for an A or AAAA query
// and the record store maps the queried name to a non-zero siteId; that siteId is passed
// to the handler. Each invocation runs in its own goroutine, so the handler must be safe
// to call concurrently.
// Pass nil to disable JIT notifications.
// NOTE(review): the field is assigned here without any lock while handleDNSQuery reads
// it; confirm callers only register the handler when no queries are being processed.
func (p *DNSProxy) SetJITHandler(handler func(siteId int)) {
p.jitHandler = handler
}
// AddDNSRecord adds a DNS record to the local store
// domain should be a domain name (e.g., "example.com" or "example.com.")
// ip should be a valid IPv4 or IPv6 address
func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP) error {
return p.recordStore.AddRecord(domain, ip)
func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP, siteId int) error {
logger.Debug("Adding dns record for domain %s with IP %s (siteId=%d)", domain, ip.String(), siteId)
return p.recordStore.AddRecord(domain, ip, siteId)
}
// RemoveDNSRecord removes a DNS record from the local store

View File

@@ -14,7 +14,7 @@ func TestCheckLocalRecordsNODATAForAAAA(t *testing.T) {
// Add an A record for a domain (no AAAA record)
ip := net.ParseIP("10.0.0.1")
err := proxy.recordStore.AddRecord("myservice.internal", ip)
err := proxy.recordStore.AddRecord("myservice.internal", ip, 0)
if err != nil {
t.Fatalf("Failed to add A record: %v", err)
}
@@ -64,7 +64,7 @@ func TestCheckLocalRecordsNODATAForA(t *testing.T) {
// Add an AAAA record for a domain (no A record)
ip := net.ParseIP("2001:db8::1")
err := proxy.recordStore.AddRecord("ipv6only.internal", ip)
err := proxy.recordStore.AddRecord("ipv6only.internal", ip, 0)
if err != nil {
t.Fatalf("Failed to add AAAA record: %v", err)
}
@@ -113,7 +113,7 @@ func TestCheckLocalRecordsNonExistentDomain(t *testing.T) {
}
// Add a record so the store isn't empty
err := proxy.recordStore.AddRecord("exists.internal", net.ParseIP("10.0.0.1"))
err := proxy.recordStore.AddRecord("exists.internal", net.ParseIP("10.0.0.1"), 0)
if err != nil {
t.Fatalf("Failed to add record: %v", err)
}
@@ -144,7 +144,7 @@ func TestCheckLocalRecordsNODATAWildcard(t *testing.T) {
// Add a wildcard A record (no AAAA)
ip := net.ParseIP("10.0.0.1")
err := proxy.recordStore.AddRecord("*.wildcard.internal", ip)
err := proxy.recordStore.AddRecord("*.wildcard.internal", ip, 0)
if err != nil {
t.Fatalf("Failed to add wildcard A record: %v", err)
}

View File

@@ -20,8 +20,9 @@ const (
// recordSet holds A and AAAA records for a single domain or wildcard pattern
type recordSet struct {
A []net.IP
AAAA []net.IP
A []net.IP
AAAA []net.IP
SiteId int
}
// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries.
@@ -46,8 +47,9 @@ func NewDNSRecordStore() *DNSRecordStore {
// domain should be in FQDN format (e.g., "example.com.")
// domain can contain wildcards: * (0+ chars) and ? (exactly 1 char)
// ip should be a valid IPv4 or IPv6 address
// siteId is the site that owns this alias/domain
// Automatically adds a corresponding PTR record for non-wildcard domains
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP, siteId int) error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -69,12 +71,22 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
}
if m[domain] == nil {
m[domain] = &recordSet{}
m[domain] = &recordSet{SiteId: siteId}
}
rs := m[domain]
if isV4 {
for _, existing := range rs.A {
if existing.Equal(ip) {
return nil
}
}
rs.A = append(rs.A, ip)
} else {
for _, existing := range rs.AAAA {
if existing.Equal(ip) {
return nil
}
}
rs.AAAA = append(rs.AAAA, ip)
}
@@ -85,6 +97,7 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
return nil
}
// AddPTRRecord adds a PTR record mapping an IP address to a domain name
// ip should be a valid IPv4 or IPv6 address
// domain should be in FQDN format (e.g., "example.com.")
@@ -179,6 +192,30 @@ func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) {
delete(s.ptrRecords, ip.String())
}
// GetSiteIdForDomain looks up the siteId stored with the record set for domain.
// The name is normalised to a lower-case FQDN before lookup. An exact entry takes
// precedence; otherwise the first wildcard pattern that matches is used. Because
// wildcards are kept in a map, which pattern is "first" is unspecified when several
// patterns match the same name.
// The boolean result is false when no exact or wildcard entry matches.
func (s *DNSRecordStore) GetSiteIdForDomain(domain string) (int, bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	fqdn := strings.ToLower(dns.Fqdn(domain))

	// Exact entries win over wildcard patterns.
	if entry, ok := s.exact[fqdn]; ok {
		return entry.SiteId, true
	}

	for wildcard, entry := range s.wildcards {
		if !matchWildcard(wildcard, fqdn) {
			continue
		}
		return entry.SiteId, true
	}

	return 0, false
}
// GetRecords returns all IP addresses for a domain and record type.
// The second return value indicates whether the domain exists at all
// (true = domain exists, use NODATA if no records; false = NXDOMAIN).

View File

@@ -170,14 +170,14 @@ func TestDNSRecordStoreWildcard(t *testing.T) {
// Add wildcard records
wildcardIP := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", wildcardIP)
err := store.AddRecord("*.autoco.internal", wildcardIP, 0)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
// Add exact record
exactIP := net.ParseIP("10.0.0.2")
err = store.AddRecord("exact.autoco.internal", exactIP)
err = store.AddRecord("exact.autoco.internal", exactIP, 0)
if err != nil {
t.Fatalf("Failed to add exact record: %v", err)
}
@@ -221,7 +221,7 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) {
// Add complex wildcard pattern
ip1 := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.host-0?.autoco.internal", ip1)
err := store.AddRecord("*.host-0?.autoco.internal", ip1, 0)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
@@ -262,7 +262,7 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
// Add wildcard record
ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", ip)
err := store.AddRecord("*.autoco.internal", ip, 0)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
@@ -297,18 +297,18 @@ func TestDNSRecordStoreMultipleWildcards(t *testing.T) {
ip2 := net.ParseIP("10.0.0.2")
ip3 := net.ParseIP("10.0.0.3")
err := store.AddRecord("*.prod.autoco.internal", ip1)
err := store.AddRecord("*.prod.autoco.internal", ip1, 0)
if err != nil {
t.Fatalf("Failed to add first wildcard: %v", err)
}
err = store.AddRecord("*.dev.autoco.internal", ip2)
err = store.AddRecord("*.dev.autoco.internal", ip2, 0)
if err != nil {
t.Fatalf("Failed to add second wildcard: %v", err)
}
// Add a broader wildcard that matches both
err = store.AddRecord("*.autoco.internal", ip3)
err = store.AddRecord("*.autoco.internal", ip3, 0)
if err != nil {
t.Fatalf("Failed to add third wildcard: %v", err)
}
@@ -337,7 +337,7 @@ func TestDNSRecordStoreIPv6Wildcard(t *testing.T) {
// Add IPv6 wildcard record
ip := net.ParseIP("2001:db8::1")
err := store.AddRecord("*.autoco.internal", ip)
err := store.AddRecord("*.autoco.internal", ip, 0)
if err != nil {
t.Fatalf("Failed to add IPv6 wildcard record: %v", err)
}
@@ -357,7 +357,7 @@ func TestHasRecordWildcard(t *testing.T) {
// Add wildcard record
ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", ip)
err := store.AddRecord("*.autoco.internal", ip, 0)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
@@ -378,7 +378,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
// Add record with mixed case
ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("MyHost.AutoCo.Internal", ip)
err := store.AddRecord("MyHost.AutoCo.Internal", ip, 0)
if err != nil {
t.Fatalf("Failed to add mixed case record: %v", err)
}
@@ -403,7 +403,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
// Test wildcard with mixed case
wildcardIP := net.ParseIP("10.0.0.2")
err = store.AddRecord("*.Example.Com", wildcardIP)
err = store.AddRecord("*.Example.Com", wildcardIP, 0)
if err != nil {
t.Fatalf("Failed to add mixed case wildcard: %v", err)
}
@@ -689,7 +689,7 @@ func TestClearPTRRecords(t *testing.T) {
store.AddPTRRecord(ip2, "host2.example.com.")
// Add some A records too
store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1"))
store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1"), 0)
// Verify PTR records exist
if !store.HasPTRRecord("1.1.168.192.in-addr.arpa.") {
@@ -719,7 +719,7 @@ func TestAutomaticPTRRecordOnAdd(t *testing.T) {
// Add an A record - should automatically add PTR record
domain := "host.example.com."
ip := net.ParseIP("192.168.1.100")
err := store.AddRecord(domain, ip)
err := store.AddRecord(domain, ip, 0)
if err != nil {
t.Fatalf("Failed to add A record: %v", err)
}
@@ -737,7 +737,7 @@ func TestAutomaticPTRRecordOnAdd(t *testing.T) {
// Add AAAA record - should also automatically add PTR record
domain6 := "ipv6host.example.com."
ip6 := net.ParseIP("2001:db8::1")
err = store.AddRecord(domain6, ip6)
err = store.AddRecord(domain6, ip6, 0)
if err != nil {
t.Fatalf("Failed to add AAAA record: %v", err)
}
@@ -759,7 +759,7 @@ func TestAutomaticPTRRecordOnRemove(t *testing.T) {
// Add an A record (with automatic PTR)
domain := "host.example.com."
ip := net.ParseIP("192.168.1.100")
store.AddRecord(domain, ip)
store.AddRecord(domain, ip, 0)
// Verify PTR exists
reverseDomain := "100.1.168.192.in-addr.arpa."
@@ -789,8 +789,8 @@ func TestAutomaticPTRRecordOnRemoveAll(t *testing.T) {
domain := "host.example.com."
ip1 := net.ParseIP("192.168.1.100")
ip2 := net.ParseIP("192.168.1.101")
store.AddRecord(domain, ip1)
store.AddRecord(domain, ip2)
store.AddRecord(domain, ip1, 0)
store.AddRecord(domain, ip2, 0)
// Verify both PTR records exist
reverseDomain1 := "100.1.168.192.in-addr.arpa."
@@ -820,7 +820,7 @@ func TestNoPTRForWildcardRecords(t *testing.T) {
// Add wildcard record - should NOT create PTR record
domain := "*.example.com."
ip := net.ParseIP("192.168.1.100")
err := store.AddRecord(domain, ip)
err := store.AddRecord(domain, ip, 0)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
@@ -844,7 +844,7 @@ func TestPTRRecordOverwrite(t *testing.T) {
// Add first domain with IP
domain1 := "host1.example.com."
ip := net.ParseIP("192.168.1.100")
store.AddRecord(domain1, ip)
store.AddRecord(domain1, ip, 0)
// Verify PTR points to first domain
reverseDomain := "100.1.168.192.in-addr.arpa."
@@ -858,7 +858,7 @@ func TestPTRRecordOverwrite(t *testing.T) {
// Add second domain with same IP - should overwrite PTR
domain2 := "host2.example.com."
store.AddRecord(domain2, ip)
store.AddRecord(domain2, ip, 0)
// Verify PTR now points to second domain (last one added)
result, ok = store.GetPTRRecord(reverseDomain)

View File

@@ -7,6 +7,7 @@ import (
"runtime"
"strconv"
"strings"
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/network"
@@ -174,21 +175,19 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
for i := range wgData.Sites {
site := wgData.Sites[i]
if site.PublicKey == "" {
logger.Warn("Skipping site %d (%s): no public key available (site may not be connected)", site.SiteId, site.Name)
continue
if site.PublicKey != "" {
var siteEndpoint string
// here we are going to take the relay endpoint if it exists which means we requested a relay for this peer
if site.RelayEndpoint != "" {
siteEndpoint = site.RelayEndpoint
} else {
siteEndpoint = site.Endpoint
}
o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false)
}
var siteEndpoint string
// here we are going to take the relay endpoint if it exists which means we requested a relay for this peer
if site.RelayEndpoint != "" {
siteEndpoint = site.RelayEndpoint
} else {
siteEndpoint = site.Endpoint
}
o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false)
// Always call AddPeer, even for a site without a public key: it registers the site's
// aliases for JIT DNS lookup and then returns early without creating a peer. Only the
// API peer-status entry above is skipped for such sites.
if err := o.peerManager.AddPeer(site); err != nil {
logger.Error("Failed to add peer: %v", err)
return
@@ -203,6 +202,36 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
logger.Error("Failed to start DNS proxy: %v", err)
}
// Register JIT handler: when the DNS proxy resolves a local record, check whether
// the owning site is already connected and, if not, initiate a JIT connection.
o.dnsProxy.SetJITHandler(func(siteId int) {
if o.peerManager == nil || o.websocket == nil {
return
}
// Site already has an active peer connection - nothing to do.
if _, exists := o.peerManager.GetPeer(siteId); exists {
return
}
o.peerSendMu.Lock()
defer o.peerSendMu.Unlock()
// A JIT request for this site is already in-flight - avoid duplicate sends.
if _, pending := o.jitPendingSites[siteId]; pending {
return
}
chainId := generateChainId()
logger.Info("DNS-triggered JIT connect for site %d (chainId=%s)", siteId, chainId)
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{
"siteId": siteId,
"chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerInits[chainId] = stopFunc
o.jitPendingSites[siteId] = chainId
})
if o.tunnelConfig.OverrideDNS {
// Set up DNS override to use our DNS proxy
if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil {

View File

@@ -2,6 +2,7 @@ package olm
import (
"encoding/json"
"fmt"
"time"
"github.com/fosrl/newt/holepunch"
@@ -220,6 +221,7 @@ func (o *Olm) handleSync(msg websocket.WSMessage) {
logger.Info("Sync: Adding new peer for site %d", siteId)
o.holePunchManager.TriggerHolePunch()
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
// // TODO: do we need to send the message to the cloud to add the peer that way?
// if err := o.peerManager.AddPeer(expectedSite); err != nil {
@@ -230,9 +232,17 @@ func (o *Olm) handleSync(msg websocket.WSMessage) {
// add the peer via the server
// this is important because newt needs to get triggered as well to add the peer once the hp is complete
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": expectedSite.SiteId,
}, 1*time.Second, 10)
chainId := fmt.Sprintf("sync-%d", expectedSite.SiteId)
o.peerSendMu.Lock()
if stop, ok := o.stopPeerSends[chainId]; ok {
stop()
}
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": expectedSite.SiteId,
"chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerSends[chainId] = stopFunc
o.peerSendMu.Unlock()
} else {
// Existing peer - check if update is needed

View File

@@ -2,6 +2,8 @@ package olm
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"net"
"net/http"
@@ -65,7 +67,10 @@ type Olm struct {
stopRegister func()
updateRegister func(newData any)
stopPeerSend func()
stopPeerSends map[string]func()
stopPeerInits map[string]func()
jitPendingSites map[int]string // siteId -> chainId for in-flight JIT requests
peerSendMu sync.Mutex
// WaitGroup to track tunnel lifecycle
tunnelWg sync.WaitGroup
@@ -116,6 +121,13 @@ func (o *Olm) initTunnelInfo(clientID string) error {
return nil
}
// generateChainId returns a 16-character lower-case hex string built from 8 bytes
// read from crypto/rand. It is used as a key for tracking peer sender lifecycles.
func generateChainId() string {
	var raw [8]byte
	// The error is deliberately discarded, as before; if the read ever failed the
	// ID would be the hex encoding of whatever bytes were filled in (zeros otherwise).
	_, _ = rand.Read(raw[:])
	return hex.EncodeToString(raw[:])
}
func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
@@ -166,10 +178,13 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
apiServer.SetAgent(config.Agent)
newOlm := &Olm{
logFile: logFile,
olmCtx: ctx,
apiServer: apiServer,
olmConfig: config,
logFile: logFile,
olmCtx: ctx,
apiServer: apiServer,
olmConfig: config,
stopPeerSends: make(map[string]func()),
stopPeerInits: make(map[string]func()),
jitPendingSites: make(map[int]string),
}
newOlm.registerAPICallbacks()
@@ -284,6 +299,21 @@ func (o *Olm) registerAPICallbacks() {
logger.Info("Processing power mode change request via API: mode=%s", req.Mode)
return o.SetPowerMode(req.Mode)
},
func(req api.JITConnectionRequest) error {
logger.Info("Processing JIT connect request via API: site=%s resource=%s", req.Site, req.Resource)
chainId := generateChainId()
o.peerSendMu.Lock()
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{
"siteId": req.Site,
"resourceId": req.Resource,
"chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerInits[chainId] = stopFunc
o.peerSendMu.Unlock()
return nil
},
)
}
@@ -385,6 +415,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
// Handler for peer handshake - adds exit node to holepunch rotation and notifies server
o.websocket.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite)
o.websocket.RegisterHandler("olm/wg/peer/chain/cancel", o.handleCancelChain)
o.websocket.RegisterHandler("olm/sync", o.handleSync)
o.websocket.OnConnect(func() error {
@@ -427,7 +458,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
"userToken": userToken,
"fingerprint": o.fingerprint,
"postures": o.postures,
}, 1*time.Second, 10)
}, 2*time.Second, 10)
// Invoke onRegistered callback if configured
if o.olmConfig.OnRegistered != nil {
@@ -524,6 +555,23 @@ func (o *Olm) Close() {
o.stopRegister = nil
}
// Stop all pending peer init and send senders before closing websocket
o.peerSendMu.Lock()
for _, stop := range o.stopPeerInits {
if stop != nil {
stop()
}
}
o.stopPeerInits = make(map[string]func())
for _, stop := range o.stopPeerSends {
if stop != nil {
stop()
}
}
o.stopPeerSends = make(map[string]func())
o.jitPendingSites = make(map[int]string)
o.peerSendMu.Unlock()
// send a disconnect message to the cloud to show disconnected
if o.websocket != nil {
o.websocket.SendMessage("olm/disconnecting", map[string]any{})

View File

@@ -20,36 +20,43 @@ func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
return
}
if o.stopPeerSend != nil {
o.stopPeerSend()
o.stopPeerSend = nil
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var siteConfig peers.SiteConfig
if err := json.Unmarshal(jsonData, &siteConfig); err != nil {
var siteConfigMsg struct {
peers.SiteConfig
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &siteConfigMsg); err != nil {
logger.Error("Error unmarshaling add data: %v", err)
return
}
if siteConfig.PublicKey == "" {
logger.Warn("Skipping add-peer for site %d (%s): no public key available (site may not be connected)", siteConfig.SiteId, siteConfig.Name)
if siteConfigMsg.ChainId != "" {
o.peerSendMu.Lock()
if stop, ok := o.stopPeerSends[siteConfigMsg.ChainId]; ok {
stop()
delete(o.stopPeerSends, siteConfigMsg.ChainId)
}
o.peerSendMu.Unlock()
}
if siteConfigMsg.PublicKey == "" {
logger.Warn("Skipping add-peer for site %d (%s): no public key available (site may not be connected)", siteConfigMsg.SiteId, siteConfigMsg.Name)
return
}
_ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
if err := o.peerManager.AddPeer(siteConfig); err != nil {
if err := o.peerManager.AddPeer(siteConfigMsg.SiteConfig); err != nil {
logger.Error("Failed to add peer: %v", err)
return
}
logger.Info("Successfully added peer for site %d", siteConfig.SiteId)
logger.Info("Successfully added peer for site %d", siteConfigMsg.SiteId)
}
func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
@@ -169,13 +176,21 @@ func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) {
return
}
var relayData peers.RelayPeerData
var relayData struct {
peers.RelayPeerData
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err)
return
}
if monitor := o.peerManager.GetPeerMonitor(); monitor != nil {
monitor.CancelRelaySend(relayData.ChainId)
}
primaryRelay, err := util.ResolveDomainUpstream(relayData.RelayEndpoint, o.tunnelConfig.PublicDNS)
if err != nil {
logger.Error("Failed to resolve primary relay endpoint: %v", err)
return
@@ -202,13 +217,21 @@ func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
return
}
var relayData peers.UnRelayPeerData
var relayData struct {
peers.UnRelayPeerData
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err)
return
}
if monitor := o.peerManager.GetPeerMonitor(); monitor != nil {
monitor.CancelRelaySend(relayData.ChainId)
}
primaryRelay, err := util.ResolveDomainUpstream(relayData.Endpoint, o.tunnelConfig.PublicDNS)
if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
}
@@ -235,7 +258,8 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
}
var handshakeData struct {
SiteId int `json:"siteId"`
SiteId int `json:"siteId"`
ChainId string `json:"chainId"`
ExitNode struct {
PublicKey string `json:"publicKey"`
Endpoint string `json:"endpoint"`
@@ -248,6 +272,19 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
return
}
// Stop the peer init sender for this chain, if any
if handshakeData.ChainId != "" {
o.peerSendMu.Lock()
if stop, ok := o.stopPeerInits[handshakeData.ChainId]; ok {
stop()
delete(o.stopPeerInits, handshakeData.ChainId)
}
// If this chain was initiated by a DNS-triggered JIT request, clear the
// pending entry so the site can be re-triggered if needed in the future.
delete(o.jitPendingSites, handshakeData.SiteId)
o.peerSendMu.Unlock()
}
// Get existing peer from PeerManager
_, exists := o.peerManager.GetPeer(handshakeData.SiteId)
if exists {
@@ -278,10 +315,72 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
// Send handshake acknowledgment back to server with retry
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": handshakeData.SiteId,
}, 1*time.Second, 10)
// Send handshake acknowledgment back to server with retry, keyed by chainId
chainId := handshakeData.ChainId
if chainId == "" {
chainId = generateChainId()
}
o.peerSendMu.Lock()
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": handshakeData.SiteId,
"chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerSends[chainId] = stopFunc
o.peerSendMu.Unlock()
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
}
// handleCancelChain handles the "olm/wg/peer/chain/cancel" websocket message.
// The message data must contain a non-empty "chainId". Under peerSendMu it:
//   - stops and forgets the peer-init sender registered for that chain, if any
//   - drops the jitPendingSites entry that points at that chain, so a later DNS
//     lookup can trigger a new JIT request for the site
//   - stops and forgets the peer-add sender registered for that chain, if any
//
// Stop functions are stored from SendMessageInterval calls whose error is
// discarded, so an entry may be nil; nil entries are removed without being
// called, matching the nil checks in Close.
func (o *Olm) handleCancelChain(msg websocket.WSMessage) {
	logger.Debug("Received cancel-chain message: %v", msg.Data)

	jsonData, err := json.Marshal(msg.Data)
	if err != nil {
		logger.Error("Error marshaling cancel-chain data: %v", err)
		return
	}

	var cancelData struct {
		ChainId string `json:"chainId"`
	}
	if err := json.Unmarshal(jsonData, &cancelData); err != nil {
		logger.Error("Error unmarshaling cancel-chain data: %v", err)
		return
	}

	if cancelData.ChainId == "" {
		logger.Warn("Received cancel-chain message with no chainId")
		return
	}

	o.peerSendMu.Lock()
	defer o.peerSendMu.Unlock()

	found := false

	if stop, ok := o.stopPeerInits[cancelData.ChainId]; ok {
		if stop != nil {
			stop()
		}
		delete(o.stopPeerInits, cancelData.ChainId)
		found = true
	}

	// If this chain was a DNS-triggered JIT request, clear the pending entry so
	// the site can be re-triggered on the next DNS lookup.
	for siteId, chainId := range o.jitPendingSites {
		if chainId == cancelData.ChainId {
			delete(o.jitPendingSites, siteId)
			break
		}
	}

	if stop, ok := o.stopPeerSends[cancelData.ChainId]; ok {
		if stop != nil {
			stop()
		}
		delete(o.stopPeerSends, cancelData.ChainId)
		found = true
	}

	if found {
		logger.Info("Cancelled chain %s", cancelData.ChainId)
	} else {
		logger.Warn("Cancel-chain: no active sender found for chain %s", cancelData.ChainId)
	}
}

View File

@@ -111,6 +111,19 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
pm.mu.Lock()
defer pm.mu.Unlock()
for _, alias := range siteConfig.Aliases {
address := net.ParseIP(alias.AliasAddress)
if address == nil {
continue
}
pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteConfig.SiteId)
}
if siteConfig.PublicKey == "" {
logger.Debug("Skip adding site %d because no pub key", siteConfig.SiteId)
return nil
}
// build the allowed IPs list from the remote subnets and aliases and add them to the peer
allowedIPs := make([]string, 0, len(siteConfig.RemoteSubnets)+len(siteConfig.Aliases))
allowedIPs = append(allowedIPs, siteConfig.RemoteSubnets...)
@@ -143,13 +156,6 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
if err := network.AddRoutes(siteConfig.RemoteSubnets, pm.interfaceName); err != nil {
logger.Error("Failed to add routes for remote subnets: %v", err)
}
for _, alias := range siteConfig.Aliases {
address := net.ParseIP(alias.AliasAddress)
if address == nil {
continue
}
pm.dnsProxy.AddDNSRecord(alias.Alias, address)
}
monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
@@ -437,7 +443,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
if address == nil {
continue
}
pm.dnsProxy.AddDNSRecord(alias.Alias, address)
pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteConfig.SiteId)
}
pm.peerMonitor.UpdateHolepunchEndpoint(siteConfig.SiteId, siteConfig.Endpoint)
@@ -717,7 +723,7 @@ func (pm *PeerManager) AddAlias(siteId int, alias Alias) error {
address := net.ParseIP(alias.AliasAddress)
if address != nil {
pm.dnsProxy.AddDNSRecord(alias.Alias, address)
pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteId)
}
// Add an allowed IP for the alias

View File

@@ -2,6 +2,8 @@ package monitor
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"net"
"net/netip"
@@ -31,11 +33,15 @@ type PeerMonitor struct {
monitors map[int]*Client
mutex sync.Mutex
running bool
timeout time.Duration
timeout time.Duration
maxAttempts int
wsClient *websocket.Client
publicDNS []string
// Relay sender tracking
relaySends map[string]func()
relaySendMu sync.Mutex
// Netstack fields
middleDev *middleDevice.MiddleDevice
localIP string
@@ -48,13 +54,13 @@ type PeerMonitor struct {
nsWg sync.WaitGroup
// Holepunch testing fields
sharedBind *bind.SharedBind
holepunchTester *holepunch.HolepunchTester
holepunchTimeout time.Duration
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
holepunchStatus map[int]bool // siteID -> connected status
holepunchStopChan chan struct{}
holepunchUpdateChan chan struct{}
sharedBind *bind.SharedBind
holepunchTester *holepunch.HolepunchTester
holepunchTimeout time.Duration
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
holepunchStatus map[int]bool // siteID -> connected status
holepunchStopChan chan struct{}
holepunchUpdateChan chan struct{}
// Relay tracking fields
relayedPeers map[int]bool // siteID -> whether the peer is currently relayed
@@ -83,6 +89,12 @@ type PeerMonitor struct {
}
// generateChainId returns a 16-character lowercase hex string built from
// 8 bytes of crypto/rand output. The result is used as the key under which a
// relay/unrelay sender's stop function is stored in relaySends.
func generateChainId() string {
	b := make([]byte, 8)
	if _, err := rand.Read(b); err != nil {
		// Without this fallback every failed read would yield the same
		// all-zero ID, and later chains would overwrite earlier entries in
		// relaySends. A time-derived value keeps IDs distinct and keeps the
		// same 16-hex-character shape.
		return fmt.Sprintf("%016x", uint64(time.Now().UnixNano()))
	}
	return hex.EncodeToString(b)
}

// NewPeerMonitor creates a new peer monitor with the given callback
func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API, publicDNS []string) *PeerMonitor {
ctx, cancel := context.WithCancel(context.Background())
pm := &PeerMonitor{
@@ -101,6 +113,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
holepunchEndpoints: make(map[int]string),
holepunchStatus: make(map[int]bool),
relayedPeers: make(map[int]bool),
relaySends: make(map[string]func()),
holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures
holepunchFailures: make(map[int]int),
// Rapid initial test settings: complete within ~1.5 seconds
@@ -398,20 +411,23 @@ func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status Connectio
}
}
// sendRelay sends a relay message to the server with retry, keyed by chainId.
//
// It returns an error when the websocket client is nil or when the one-shot
// SendMessage call fails; otherwise it returns nil after registering the
// interval sender's stop function in relaySends under a fresh chain ID.
//
// NOTE(review): this span appears to interleave the pre-change body (one-shot
// SendMessage followed by the "Sent relay message" log) with the post-change
// body (interval sender keyed by chainId) — confirm which statements belong in
// the final code before relying on this description.
func (pm *PeerMonitor) sendRelay(siteID int) error {
if pm.wsClient == nil {
return fmt.Errorf("websocket client is nil")
}
// One-shot send of the relay request for this site.
err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{
"siteId": siteID,
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err
}
logger.Info("Sent relay message")
// Tag this send with a fresh chain ID and start an interval sender for it.
// The second return value of SendMessageInterval is discarded here.
// presumably 2*time.Second is the resend interval and 10 the attempt cap — TODO confirm.
chainId := generateChainId()
stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/relay", map[string]interface{}{
"siteId": siteID,
"chainId": chainId,
}, 2*time.Second, 10)
// Record the stop function under the chain ID so the sender can be looked up
// and stopped later via relaySends.
pm.relaySendMu.Lock()
pm.relaySends[chainId] = stopFunc
pm.relaySendMu.Unlock()
logger.Info("Sent relay message for site %d (chain %s)", siteID, chainId)
return nil
}
@@ -421,23 +437,52 @@ func (pm *PeerMonitor) RequestRelay(siteID int) error {
return pm.sendRelay(siteID)
}
// sendUnRelay sends an unrelay message to the server with retry, keyed by chainId.
//
// It returns an error when the websocket client is nil or when the one-shot
// SendMessage call fails; otherwise it returns nil after registering the
// interval sender's stop function in relaySends under a fresh chain ID.
//
// NOTE(review): this span appears to interleave the pre-change body (one-shot
// SendMessage followed by the "Sent unrelay message" log) with the post-change
// body (interval sender keyed by chainId) — confirm which statements belong in
// the final code before relying on this description.
func (pm *PeerMonitor) sendUnRelay(siteID int) error {
if pm.wsClient == nil {
return fmt.Errorf("websocket client is nil")
}
// One-shot send of the unrelay request for this site.
err := pm.wsClient.SendMessage("olm/wg/unrelay", map[string]interface{}{
"siteId": siteID,
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err
}
logger.Info("Sent unrelay message")
// Tag this send with a fresh chain ID and start an interval sender for it.
// The second return value of SendMessageInterval is discarded here.
// presumably 2*time.Second is the resend interval and 10 the attempt cap — TODO confirm.
chainId := generateChainId()
stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/unrelay", map[string]interface{}{
"siteId": siteID,
"chainId": chainId,
}, 2*time.Second, 10)
// Record the stop function under the chain ID so the sender can be looked up
// and stopped later via relaySends.
pm.relaySendMu.Lock()
pm.relaySends[chainId] = stopFunc
pm.relaySendMu.Unlock()
logger.Info("Sent unrelay message for site %d (chain %s)", siteID, chainId)
return nil
}
// CancelRelaySend stops the interval sender for the given chainId, if one exists.
// If chainId is empty, all active relay senders are stopped.
//
// Each stopped sender is removed from relaySends. Stop functions are invoked
// while relaySendMu is held. A nil stop function is skipped rather than
// called, in both the single-chain and the cancel-all paths.
func (pm *PeerMonitor) CancelRelaySend(chainId string) {
	pm.relaySendMu.Lock()
	defer pm.relaySendMu.Unlock()

	// Empty chainId means "cancel everything".
	if chainId == "" {
		for id, stop := range pm.relaySends {
			if stop != nil {
				stop()
			}
			delete(pm.relaySends, id)
		}
		logger.Info("Cancelled all relay senders")
		return
	}

	stop, ok := pm.relaySends[chainId]
	if !ok {
		logger.Warn("CancelRelaySend: no active sender for chain %s", chainId)
		return
	}
	// Entries are stored without a nil check, so guard here the same way the
	// cancel-all path does; calling a nil func would panic.
	if stop != nil {
		stop()
	}
	delete(pm.relaySends, chainId)
	logger.Info("Cancelled relay sender for chain %s", chainId)
}
// Stop stops monitoring all peers
func (pm *PeerMonitor) Stop() {
// Stop holepunch monitor first (outside of mutex to avoid deadlock)
@@ -679,6 +724,16 @@ func (pm *PeerMonitor) Close() {
// Stop holepunch monitor first (outside of mutex to avoid deadlock)
pm.stopHolepunchMonitor()
// Stop all pending relay senders
pm.relaySendMu.Lock()
for chainId, stop := range pm.relaySends {
if stop != nil {
stop()
}
delete(pm.relaySends, chainId)
}
pm.relaySendMu.Unlock()
pm.mutex.Lock()
defer pm.mutex.Unlock()