From 21b66fbb34264bede6870db575191739c05db347 Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 25 Feb 2026 14:57:56 -0800 Subject: [PATCH 01/15] Update iss --- olm.iss | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/olm.iss b/olm.iss index 1893f8e..4216d88 100644 --- a/olm.iss +++ b/olm.iss @@ -32,7 +32,7 @@ DefaultGroupName={#MyAppName} DisableProgramGroupPage=yes ; Uncomment the following line to run in non administrative install mode (install for current user only). ;PrivilegesRequired=lowest -OutputBaseFilename=mysetup +OutputBaseFilename=olm_windows_installer SolidCompression=yes WizardStyle=modern ; Add this to ensure PATH changes are applied and the system is prompted for a restart if needed @@ -78,7 +78,7 @@ begin Result := True; exit; end; - + // Perform a case-insensitive check to see if the path is already present. // We add semicolons to prevent partial matches (e.g., matching C:\App in C:\App2). if Pos(';' + UpperCase(Path) + ';', ';' + UpperCase(OrigPath) + ';') > 0 then @@ -109,7 +109,7 @@ begin PathList.Delimiter := ';'; PathList.StrictDelimiter := True; PathList.DelimitedText := OrigPath; - + // Find and remove the matching entry (case-insensitive) for I := PathList.Count - 1 downto 0 do begin @@ -119,10 +119,10 @@ begin PathList.Delete(I); end; end; - + // Reconstruct the PATH NewPath := PathList.DelimitedText; - + // Write the new PATH back to the registry if RegWriteExpandStringValue(HKEY_LOCAL_MACHINE, 'SYSTEM\CurrentControlSet\Control\Session Manager\Environment', @@ -145,8 +145,8 @@ begin // Get the application installation path AppPath := ExpandConstant('{app}'); Log('Removing PATH entry for: ' + AppPath); - + // Remove only our path entry from the system PATH RemovePathEntry(AppPath); end; -end; +end; \ No newline at end of file From e7507e08376903d4d98c376f186a5a10a62f2ede Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 4 Mar 2026 17:01:17 -0800 Subject: [PATCH 02/15] Add api endpoints to jit --- api/api.go | 59 +++++++++++++++++++++++++++++++++++++++++++++++++++++ olm/data.go | 1 + olm/olm.go | 11 ++++++++++ 3 files changed, 71 insertions(+) diff --git a/api/api.go b/api/api.go index 047ce08..895140b 100644 --- a/api/api.go +++ b/api/api.go @@ -78,6 +78,13 @@ type MetadataChangeRequest struct { Postures map[string]any `json:"postures"` } +// JITConnectionRequest defines the structure for a dynamic Just-In-Time connection request. +// Either SiteID or ResourceID must be provided (but not necessarily both). +type JITConnectionRequest struct { + Site string `json:"site,omitempty"` + Resource string `json:"resource,omitempty"` +} + // API represents the HTTP server and its state type API struct { addr string @@ -92,6 +99,7 @@ type API struct { onExit func() error onRebind func() error onPowerMode func(PowerModeRequest) error + onJITConnect func(JITConnectionRequest) error statusMu sync.RWMutex peerStatuses map[int]*PeerStatus @@ -143,6 +151,7 @@ func (s *API) SetHandlers( onExit func() error, onRebind func() error, onPowerMode func(PowerModeRequest) error, + onJITConnect func(JITConnectionRequest) error, ) { s.onConnect = onConnect s.onSwitchOrg = onSwitchOrg @@ -151,6 +160,7 @@ func (s *API) SetHandlers( s.onExit = onExit s.onRebind = onRebind s.onPowerMode = onPowerMode + s.onJITConnect = onJITConnect } // Start starts the HTTP server @@ -169,6 +179,7 @@ func (s *API) Start() error { mux.HandleFunc("/health", s.handleHealth) mux.HandleFunc("/rebind", s.handleRebind) mux.HandleFunc("/power-mode", s.handlePowerMode) + mux.HandleFunc("/jit-connect", s.handleJITConnect) s.server = &http.Server{ Handler: mux, @@ -633,6 +644,54 @@ func (s *API) handleRebind(w http.ResponseWriter, r *http.Request) { }) } +// handleJITConnect handles the /jit-connect endpoint. +// It initiates a dynamic Just-In-Time connection to a site identified by either +// a site or a resource. Exactly one of the two must be provided. +func (s *API) handleJITConnect(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req JITConnectionRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest) + return + } + + // Validate that exactly one of site or resource is provided + if req.Site == "" && req.Resource == "" { + http.Error(w, "Missing required field: either site or resource must be provided", http.StatusBadRequest) + return + } + if req.Site != "" && req.Resource != "" { + http.Error(w, "Ambiguous request: provide either site or resource, not both", http.StatusBadRequest) + return + } + + if req.Site != "" { + logger.Info("Received JIT connection request via API: site=%s", req.Site) + } else { + logger.Info("Received JIT connection request via API: resource=%s", req.Resource) + } + + if s.onJITConnect != nil { + if err := s.onJITConnect(req); err != nil { + http.Error(w, fmt.Sprintf("JIT connection failed: %v", err), http.StatusInternalServerError) + return + } + } else { + http.Error(w, "JIT connect handler not configured", http.StatusNotImplemented) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + _ = json.NewEncoder(w).Encode(map[string]string{ + "status": "JIT connection request accepted", + }) +} + // handlePowerMode handles the /power-mode endpoint // This allows changing the power mode between "normal" and "low" func (s *API) handlePowerMode(w http.ResponseWriter, r *http.Request) { diff --git a/olm/data.go b/olm/data.go index 8bd0997..015931b 100644 --- a/olm/data.go +++ b/olm/data.go @@ -220,6 +220,7 @@ func (o *Olm) handleSync(msg websocket.WSMessage) { logger.Info("Sync: Adding new peer for site %d", siteId) o.holePunchManager.TriggerHolePunch() + o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud // // TODO: do we need to send the message to the cloud to add the peer that way? // if err := o.peerManager.AddPeer(expectedSite); err != nil { diff --git a/olm/olm.go b/olm/olm.go index 9bd41b2..fa32ebd 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -66,6 +66,7 @@ type Olm struct { updateRegister func(newData any) stopPeerSend func() + stopPeerInit func() // WaitGroup to track tunnel lifecycle tunnelWg sync.WaitGroup @@ -284,6 +285,16 @@ func (o *Olm) registerAPICallbacks() { logger.Info("Processing power mode change request via API: mode=%s", req.Mode) return o.SetPowerMode(req.Mode) }, + func(req api.JITConnectionRequest) error { + logger.Info("Processing JIT connect request via API: site=%s resource=%s", req.Site, req.Resource) + + o.stopPeerInit, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{ + "siteId": req.Site, + "resourceId": req.Resource, + }, 2*time.Second, 10) + + return nil + }, ) } From 051c0fdfd830cbc78c40d310321788379f3ad06e Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 4 Mar 2026 17:51:48 -0800 Subject: [PATCH 03/15] Working jit with chain ids --- olm/data.go | 15 ++++++-- olm/olm.go | 58 +++++++++++++++++++++++------- olm/peer.go | 101 ++++++++++++++++++++++++++++++++++++++++++++-------- 3 files changed, 145 insertions(+), 29 deletions(-) diff --git a/olm/data.go b/olm/data.go index 015931b..d0e6d5b 100644 --- a/olm/data.go +++ b/olm/data.go @@ -2,6 +2,7 @@ package olm import ( "encoding/json" + "fmt" "time" "github.com/fosrl/newt/holepunch" @@ -231,9 +232,17 @@ func (o *Olm) handleSync(msg websocket.WSMessage) { // add the peer via the server // this is important because newt needs to get triggered as well to add the peer once the hp is complete - o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ - "siteId": expectedSite.SiteId, - }, 1*time.Second, 10) + chainId := fmt.Sprintf("sync-%d", expectedSite.SiteId) + o.peerSendMu.Lock() + if stop, ok := o.stopPeerSends[chainId]; ok { + stop() + } + stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ + "siteId": expectedSite.SiteId, + "chainId": chainId, + }, 2*time.Second, 10) + o.stopPeerSends[chainId] = stopFunc + o.peerSendMu.Unlock() } else { // Existing peer - check if update is needed diff --git a/olm/olm.go b/olm/olm.go index fa32ebd..b2843d2 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -2,6 +2,8 @@ package olm import ( "context" + "crypto/rand" + "encoding/hex" "fmt" "net" "net/http" @@ -65,8 +67,9 @@ type Olm struct { stopRegister func() updateRegister func(newData any) - stopPeerSend func() - stopPeerInit func() + stopPeerSends map[string]func() + stopPeerInits map[string]func() + peerSendMu sync.Mutex // WaitGroup to track tunnel lifecycle tunnelWg sync.WaitGroup @@ -117,6 +120,13 @@ func (o *Olm) initTunnelInfo(clientID string) error { return nil } +// generateChainId generates a random chain ID for tracking peer sender lifecycles. +func generateChainId() string { + b := make([]byte, 8) + _, _ = rand.Read(b) + return hex.EncodeToString(b) +} + func Init(ctx context.Context, config OlmConfig) (*Olm, error) { logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel)) @@ -167,10 +177,12 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) { apiServer.SetAgent(config.Agent) newOlm := &Olm{ - logFile: logFile, - olmCtx: ctx, - apiServer: apiServer, - olmConfig: config, + logFile: logFile, + olmCtx: ctx, + apiServer: apiServer, + olmConfig: config, + stopPeerSends: make(map[string]func()), + stopPeerInits: make(map[string]func()), } newOlm.registerAPICallbacks() @@ -287,12 +299,17 @@ func (o *Olm) registerAPICallbacks() { }, func(req api.JITConnectionRequest) error { logger.Info("Processing JIT connect request via API: site=%s resource=%s", req.Site, req.Resource) - - o.stopPeerInit, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{ - "siteId": req.Site, + + chainId := generateChainId() + o.peerSendMu.Lock() + stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{ + "siteId": req.Site, "resourceId": req.Resource, - }, 2*time.Second, 10) - + "chainId": chainId, + }, 2*time.Second, 10) + o.stopPeerInits[chainId] = stopFunc + o.peerSendMu.Unlock() + return nil }, ) @@ -389,6 +406,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) { // Handler for peer handshake - adds exit node to holepunch rotation and notifies server o.websocket.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite) + o.websocket.RegisterHandler("olm/wg/peer/chain/cancel", o.handleCancelChain) o.websocket.RegisterHandler("olm/sync", o.handleSync) o.websocket.OnConnect(func() error { @@ -431,7 +449,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) { "userToken": userToken, "fingerprint": o.fingerprint, "postures": o.postures, - }, 1*time.Second, 10) + }, 2*time.Second, 10) // Invoke onRegistered callback if configured if o.olmConfig.OnRegistered != nil { @@ -528,6 +546,22 @@ func (o *Olm) Close() { o.stopRegister = nil } + // Stop all pending peer init and send senders before closing websocket + o.peerSendMu.Lock() + for _, stop := range o.stopPeerInits { + if stop != nil { + stop() + } + } + o.stopPeerInits = make(map[string]func()) + for _, stop := range o.stopPeerSends { + if stop != nil { + stop() + } + } + o.stopPeerSends = make(map[string]func()) + o.peerSendMu.Unlock() + // send a disconnect message to the cloud to show disconnected if o.websocket != nil { o.websocket.SendMessage("olm/disconnecting", map[string]any{}) diff --git a/olm/peer.go b/olm/peer.go index 8007272..1937934 100644 --- a/olm/peer.go +++ b/olm/peer.go @@ -20,31 +20,39 @@ func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) { return } - if o.stopPeerSend != nil { - o.stopPeerSend() - o.stopPeerSend = nil - } - jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Error("Error marshaling data: %v", err) return } - var siteConfig peers.SiteConfig - if err := json.Unmarshal(jsonData, &siteConfig); err != nil { + var siteConfigMsg struct { + peers.SiteConfig + ChainId string `json:"chainId"` + } + if err := json.Unmarshal(jsonData, &siteConfigMsg); err != nil { logger.Error("Error unmarshaling add data: %v", err) return } + if siteConfigMsg.ChainId != "" { + o.peerSendMu.Lock() + if stop, ok := o.stopPeerSends[siteConfigMsg.ChainId]; ok { + stop() + delete(o.stopPeerSends, siteConfigMsg.ChainId) + } + o.peerSendMu.Unlock() + } + _ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it - if err := o.peerManager.AddPeer(siteConfig); err != nil { + if err := o.peerManager.AddPeer(siteConfigMsg.SiteConfig); err != nil { logger.Error("Failed to add peer: %v", err) return } - logger.Info("Successfully added peer for site %d", siteConfig.SiteId) + + logger.Info("Successfully added peer for site %d", siteConfigMsg.SiteId) } func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) { @@ -230,7 +238,8 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { } var handshakeData struct { - SiteId int `json:"siteId"` + SiteId int `json:"siteId"` + ChainId string `json:"chainId"` ExitNode struct { PublicKey string `json:"publicKey"` Endpoint string `json:"endpoint"` @@ -242,6 +251,16 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { logger.Error("Error unmarshaling handshake data: %v", err) return } + + // Stop the peer init sender for this chain, if any + if handshakeData.ChainId != "" { + o.peerSendMu.Lock() + if stop, ok := o.stopPeerInits[handshakeData.ChainId]; ok { + stop() + delete(o.stopPeerInits, handshakeData.ChainId) + } + o.peerSendMu.Unlock() + } // Get existing peer from PeerManager _, exists := o.peerManager.GetPeer(handshakeData.SiteId) @@ -273,10 +292,64 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud - // Send handshake acknowledgment back to server with retry - o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ - "siteId": handshakeData.SiteId, - }, 1*time.Second, 10) + // Send handshake acknowledgment back to server with retry, keyed by chainId + chainId := handshakeData.ChainId + if chainId == "" { + chainId = generateChainId() + } + o.peerSendMu.Lock() + stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ + "siteId": handshakeData.SiteId, + "chainId": chainId, + }, 2*time.Second, 10) + o.stopPeerSends[chainId] = stopFunc + o.peerSendMu.Unlock() logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) } + +func (o *Olm) handleCancelChain(msg websocket.WSMessage) { + logger.Debug("Received cancel-chain message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling cancel-chain data: %v", err) + return + } + + var cancelData struct { + ChainId string `json:"chainId"` + } + if err := json.Unmarshal(jsonData, &cancelData); err != nil { + logger.Error("Error unmarshaling cancel-chain data: %v", err) + return + } + + if cancelData.ChainId == "" { + logger.Warn("Received cancel-chain message with no chainId") + return + } + + o.peerSendMu.Lock() + defer o.peerSendMu.Unlock() + + found := false + + if stop, ok := o.stopPeerInits[cancelData.ChainId]; ok { + stop() + delete(o.stopPeerInits, cancelData.ChainId) + found = true + } + + if stop, ok := o.stopPeerSends[cancelData.ChainId]; ok { + stop() + delete(o.stopPeerSends, cancelData.ChainId) + found = true + } + + if found { + logger.Info("Cancelled chain %s", cancelData.ChainId) + } else { + logger.Warn("Cancel-chain: no active sender found for chain %s", cancelData.ChainId) + } +} From c67c2a60a1e2037c4d3c228b8c5fb9c6b4355001 Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 6 Mar 2026 15:15:31 -0800 Subject: [PATCH 04/15] Handle canceling sends for relay --- olm/peer.go | 22 ++++++++++- peers/monitor/monitor.go | 79 +++++++++++++++++++++++++++++++--------- 2 files changed, 81 insertions(+), 20 deletions(-) diff --git a/olm/peer.go b/olm/peer.go index 1937934..c611921 100644 --- a/olm/peer.go +++ b/olm/peer.go @@ -172,12 +172,21 @@ func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) { return } - var relayData peers.RelayPeerData + var relayData struct { + peers.RelayPeerData + ChainId string `json:"chainId"` + } if err := json.Unmarshal(jsonData, &relayData); err != nil { logger.Error("Error unmarshaling relay data: %v", err) return } + if relayData.ChainId != "" { + if monitor := o.peerManager.GetPeerMonitor(); monitor != nil { + monitor.CancelRelaySend(relayData.ChainId) + } + } + primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint) if err != nil { logger.Error("Failed to resolve primary relay endpoint: %v", err) @@ -205,12 +214,21 @@ func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) { return } - var relayData peers.UnRelayPeerData + var relayData struct { + peers.UnRelayPeerData + ChainId string `json:"chainId"` + } if err := json.Unmarshal(jsonData, &relayData); err != nil { logger.Error("Error unmarshaling relay data: %v", err) return } + if relayData.ChainId != "" { + if monitor := o.peerManager.GetPeerMonitor(); monitor != nil { + monitor.CancelRelaySend(relayData.ChainId) + } + } + primaryRelay, err := util.ResolveDomain(relayData.Endpoint) if err != nil { logger.Warn("Failed to resolve primary relay endpoint: %v", err) diff --git a/peers/monitor/monitor.go b/peers/monitor/monitor.go index 28d92ef..1296fef 100644 --- a/peers/monitor/monitor.go +++ b/peers/monitor/monitor.go @@ -2,6 +2,8 @@ package monitor import ( "context" + "crypto/rand" + "encoding/hex" "fmt" "net" "net/netip" @@ -35,6 +37,10 @@ type PeerMonitor struct { maxAttempts int wsClient *websocket.Client + // Relay sender tracking + relaySends map[string]func() + relaySendMu sync.Mutex + // Netstack fields middleDev *middleDevice.MiddleDevice localIP string @@ -82,6 +88,12 @@ type PeerMonitor struct { } // NewPeerMonitor creates a new peer monitor with the given callback +func generateChainId() string { + b := make([]byte, 8) + _, _ = rand.Read(b) + return hex.EncodeToString(b) +} + func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API) *PeerMonitor { ctx, cancel := context.WithCancel(context.Background()) pm := &PeerMonitor{ @@ -99,6 +111,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe holepunchEndpoints: make(map[int]string), holepunchStatus: make(map[int]bool), relayedPeers: make(map[int]bool), + relaySends: make(map[string]func()), holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures holepunchFailures: make(map[int]int), // Rapid initial test settings: complete within ~1.5 seconds @@ -396,20 +409,23 @@ func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status Connectio } } -// sendRelay sends a relay message to the server +// sendRelay sends a relay message to the server with retry, keyed by chainId func (pm *PeerMonitor) sendRelay(siteID int) error { if pm.wsClient == nil { return fmt.Errorf("websocket client is nil") } - err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{ - "siteId": siteID, - }) - if err != nil { - logger.Error("Failed to send registration message: %v", err) - return err - } - logger.Info("Sent relay message") + chainId := generateChainId() + stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/relay", map[string]interface{}{ + "siteId": siteID, + "chainId": chainId, + }, 2*time.Second, 10) + + pm.relaySendMu.Lock() + pm.relaySends[chainId] = stopFunc + pm.relaySendMu.Unlock() + + logger.Info("Sent relay message for site %d (chain %s)", siteID, chainId) return nil } @@ -419,23 +435,40 @@ func (pm *PeerMonitor) RequestRelay(siteID int) error { return pm.sendRelay(siteID) } -// sendUnRelay sends an unrelay message to the server +// sendUnRelay sends an unrelay message to the server with retry, keyed by chainId func (pm *PeerMonitor) sendUnRelay(siteID int) error { if pm.wsClient == nil { return fmt.Errorf("websocket client is nil") } - err := pm.wsClient.SendMessage("olm/wg/unrelay", map[string]interface{}{ - "siteId": siteID, - }) - if err != nil { - logger.Error("Failed to send registration message: %v", err) - return err - } - logger.Info("Sent unrelay message") + chainId := generateChainId() + stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/unrelay", map[string]interface{}{ + "siteId": siteID, + "chainId": chainId, + }, 2*time.Second, 10) + + pm.relaySendMu.Lock() + pm.relaySends[chainId] = stopFunc + pm.relaySendMu.Unlock() + + logger.Info("Sent unrelay message for site %d (chain %s)", siteID, chainId) return nil } +// CancelRelaySend stops the interval sender for the given chainId, if one exists. +func (pm *PeerMonitor) CancelRelaySend(chainId string) { + pm.relaySendMu.Lock() + defer pm.relaySendMu.Unlock() + + if stop, ok := pm.relaySends[chainId]; ok { + stop() + delete(pm.relaySends, chainId) + logger.Info("Cancelled relay sender for chain %s", chainId) + } else { + logger.Warn("CancelRelaySend: no active sender for chain %s", chainId) + } +} + // Stop stops monitoring all peers func (pm *PeerMonitor) Stop() { // Stop holepunch monitor first (outside of mutex to avoid deadlock) @@ -677,6 +710,16 @@ func (pm *PeerMonitor) Close() { // Stop holepunch monitor first (outside of mutex to avoid deadlock) pm.stopHolepunchMonitor() + // Stop all pending relay senders + pm.relaySendMu.Lock() + for chainId, stop := range pm.relaySends { + if stop != nil { + stop() + } + delete(pm.relaySends, chainId) + } + pm.relaySendMu.Unlock() + pm.mutex.Lock() defer pm.mutex.Unlock() From 809dbe77de0042d038f86b2a5fbedb53a3af8e61 Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 6 Mar 2026 15:27:03 -0800 Subject: [PATCH 05/15] Make chainId in relay message bckwd compat --- olm/peer.go | 21 ++++++++------------- peers/monitor/monitor.go | 34 +++++++++++++++++++++++----------- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/olm/peer.go b/olm/peer.go index c611921..9f02bb2 100644 --- a/olm/peer.go +++ b/olm/peer.go @@ -51,7 +51,6 @@ func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) { return } - logger.Info("Successfully added peer for site %d", siteConfigMsg.SiteId) } @@ -181,10 +180,8 @@ func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) { return } - if relayData.ChainId != "" { - if monitor := o.peerManager.GetPeerMonitor(); monitor != nil { - monitor.CancelRelaySend(relayData.ChainId) - } + if monitor := o.peerManager.GetPeerMonitor(); monitor != nil { + monitor.CancelRelaySend(relayData.ChainId) } primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint) @@ -223,10 +220,8 @@ func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) { return } - if relayData.ChainId != "" { - if monitor := o.peerManager.GetPeerMonitor(); monitor != nil { - monitor.CancelRelaySend(relayData.ChainId) - } + if monitor := o.peerManager.GetPeerMonitor(); monitor != nil { + monitor.CancelRelaySend(relayData.ChainId) } primaryRelay, err := util.ResolveDomain(relayData.Endpoint) @@ -256,8 +251,8 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { } var handshakeData struct { - SiteId int `json:"siteId"` - ChainId string `json:"chainId"` + SiteId int `json:"siteId"` + ChainId string `json:"chainId"` ExitNode struct { PublicKey string `json:"publicKey"` Endpoint string `json:"endpoint"` @@ -269,7 +264,7 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { logger.Error("Error unmarshaling handshake data: %v", err) return } - + // Stop the peer init sender for this chain, if any if handshakeData.ChainId != "" { o.peerSendMu.Lock() @@ -278,7 +273,7 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { delete(o.stopPeerInits, handshakeData.ChainId) } o.peerSendMu.Unlock() - } + } // Get existing peer from PeerManager _, exists := o.peerManager.GetPeer(handshakeData.SiteId) diff --git a/peers/monitor/monitor.go b/peers/monitor/monitor.go index 1296fef..6b0d557 100644 --- a/peers/monitor/monitor.go +++ b/peers/monitor/monitor.go @@ -33,13 +33,13 @@ type PeerMonitor struct { monitors map[int]*Client mutex sync.Mutex running bool - timeout time.Duration + timeout time.Duration maxAttempts int wsClient *websocket.Client // Relay sender tracking - relaySends map[string]func() - relaySendMu sync.Mutex + relaySends map[string]func() + relaySendMu sync.Mutex // Netstack fields middleDev *middleDevice.MiddleDevice @@ -53,13 +53,13 @@ type PeerMonitor struct { nsWg sync.WaitGroup // Holepunch testing fields - sharedBind *bind.SharedBind - holepunchTester *holepunch.HolepunchTester - holepunchTimeout time.Duration - holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing - holepunchStatus map[int]bool // siteID -> connected status - holepunchStopChan chan struct{} - holepunchUpdateChan chan struct{} + sharedBind *bind.SharedBind + holepunchTester *holepunch.HolepunchTester + holepunchTimeout time.Duration + holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing + holepunchStatus map[int]bool // siteID -> connected status + holepunchStopChan chan struct{} + holepunchUpdateChan chan struct{} // Relay tracking fields relayedPeers map[int]bool // siteID -> whether the peer is currently relayed @@ -456,10 +456,22 @@ func (pm *PeerMonitor) sendUnRelay(siteID int) error { } // CancelRelaySend stops the interval sender for the given chainId, if one exists. +// If chainId is empty, all active relay senders are stopped. func (pm *PeerMonitor) CancelRelaySend(chainId string) { pm.relaySendMu.Lock() defer pm.relaySendMu.Unlock() + if chainId == "" { + for id, stop := range pm.relaySends { + if stop != nil { + stop() + } + delete(pm.relaySends, id) + } + logger.Info("Cancelled all relay senders") + return + } + if stop, ok := pm.relaySends[chainId]; ok { stop() delete(pm.relaySends, chainId) @@ -567,7 +579,7 @@ func (pm *PeerMonitor) runHolepunchMonitor() { pm.holepunchCurrentInterval = pm.holepunchMinInterval currentInterval := pm.holepunchCurrentInterval pm.mutex.Unlock() - + timer.Reset(currentInterval) logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval) case <-timer.C: From 5ca48258007eac6be7ee0777aa091d2731ff9a6e Mon Sep 17 00:00:00 2001 From: Laurence Date: Thu, 26 Feb 2026 11:15:42 +0000 Subject: [PATCH 06/15] refactor(dns): trie + unified record set for DNSRecordStore - Replace four maps (aRecords, aaaaRecords, aWildcards, aaaaWildcards) with a label trie for exact lookups and a single wildcards map - Store one recordSet (A + AAAA) per domain/pattern instead of separate A and AAAA maps - Exact lookups O(labels); PTR unchanged (map); API and behaviour unchanged --- dns/dns_records.go | 358 +++++++++++++++++++++++---------------------- 1 file changed, 185 insertions(+), 173 deletions(-) diff --git a/dns/dns_records.go b/dns/dns_records.go index 199b94b..5c62043 100644 --- a/dns/dns_records.go +++ b/dns/dns_records.go @@ -18,24 +18,49 @@ const ( RecordTypePTR RecordType = RecordType(dns.TypePTR) ) -// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries +// recordSet holds A and AAAA records for a single domain or wildcard pattern +type recordSet struct { + A []net.IP + AAAA []net.IP +} + +// domainTrieNode is a node in the trie for exact domain lookups (no wildcards in path) +type domainTrieNode struct { + children map[string]*domainTrieNode + data *recordSet +} + +// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries. +// Exact domains are stored in a trie for O(label count) lookup; wildcard patterns +// are in a separate map. Each domain/pattern has a single recordSet (A + AAAA). type DNSRecordStore struct { - mu sync.RWMutex - aRecords map[string][]net.IP // domain -> list of IPv4 addresses - aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses - aWildcards map[string][]net.IP // wildcard pattern -> list of IPv4 addresses - aaaaWildcards map[string][]net.IP // wildcard pattern -> list of IPv6 addresses - ptrRecords map[string]string // IP address string -> domain name + mu sync.RWMutex + root *domainTrieNode // trie root for exact lookups + wildcards map[string]*recordSet // wildcard pattern -> A/AAAA records + ptrRecords map[string]string // IP address string -> domain name +} + +// domainToPath converts a FQDN to a trie path (reversed labels, e.g. "host.internal." -> ["internal", "host"]) +func domainToPath(domain string) []string { + domain = strings.ToLower(dns.Fqdn(domain)) + domain = strings.TrimSuffix(domain, ".") + if domain == "" { + return nil + } + labels := strings.Split(domain, ".") + path := make([]string, 0, len(labels)) + for i := len(labels) - 1; i >= 0; i-- { + path = append(path, labels[i]) + } + return path } // NewDNSRecordStore creates a new DNS record store func NewDNSRecordStore() *DNSRecordStore { return &DNSRecordStore{ - aRecords: make(map[string][]net.IP), - aaaaRecords: make(map[string][]net.IP), - aWildcards: make(map[string][]net.IP), - aaaaWildcards: make(map[string][]net.IP), - ptrRecords: make(map[string]string), + root: &domainTrieNode{children: make(map[string]*domainTrieNode)}, + wildcards: make(map[string]*recordSet), + ptrRecords: make(map[string]string), } } @@ -48,39 +73,47 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error { s.mu.Lock() defer s.mu.Unlock() - // Ensure domain ends with a dot (FQDN format) if len(domain) == 0 || domain[len(domain)-1] != '.' { domain = domain + "." } - - // Normalize domain to lowercase FQDN domain = strings.ToLower(dns.Fqdn(domain)) - - // Check if domain contains wildcards isWildcard := strings.ContainsAny(domain, "*?") - if ip.To4() != nil { - // IPv4 address - if isWildcard { - s.aWildcards[domain] = append(s.aWildcards[domain], ip) - } else { - s.aRecords[domain] = append(s.aRecords[domain], ip) - // Automatically add PTR record for non-wildcard domains - s.ptrRecords[ip.String()] = domain - } - } else if ip.To16() != nil { - // IPv6 address - if isWildcard { - s.aaaaWildcards[domain] = append(s.aaaaWildcards[domain], ip) - } else { - s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip) - // Automatically add PTR record for non-wildcard domains - s.ptrRecords[ip.String()] = domain - } - } else { + isV4 := ip.To4() != nil + if !isV4 && ip.To16() == nil { return &net.ParseError{Type: "IP address", Text: ip.String()} } + if isWildcard { + if s.wildcards[domain] == nil { + s.wildcards[domain] = &recordSet{} + } + rs := s.wildcards[domain] + if isV4 { + rs.A = append(rs.A, ip) + } else { + rs.AAAA = append(rs.AAAA, ip) + } + return nil + } + + path := domainToPath(domain) + node := s.root + for _, label := range path { + if node.children[label] == nil { + node.children[label] = &domainTrieNode{children: make(map[string]*domainTrieNode)} + } + node = node.children[label] + } + if node.data == nil { + node.data = &recordSet{} + } + if isV4 { + node.data.A = append(node.data.A, ip) + } else { + node.data.AAAA = append(node.data.AAAA, ip) + } + s.ptrRecords[ip.String()] = domain return nil } @@ -112,89 +145,74 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) { s.mu.Lock() defer s.mu.Unlock() - // Ensure domain ends with a dot (FQDN format) if len(domain) == 0 || domain[len(domain)-1] != '.' { domain = domain + "." } - - // Normalize domain to lowercase FQDN domain = strings.ToLower(dns.Fqdn(domain)) - - // Check if domain contains wildcards isWildcard := strings.ContainsAny(domain, "*?") - if ip == nil { - // Remove all records for this domain - if isWildcard { - delete(s.aWildcards, domain) - delete(s.aaaaWildcards, domain) + if isWildcard { + if ip == nil { + delete(s.wildcards, domain) + return + } + rs := s.wildcards[domain] + if rs == nil { + return + } + if ip.To4() != nil { + rs.A = removeIP(rs.A, ip) } else { - // For non-wildcard domains, remove PTR records for all IPs - if ips, ok := s.aRecords[domain]; ok { - for _, ipAddr := range ips { - // Only remove PTR if it points to this domain - if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ipAddr.String()) - } - } - } - if ips, ok := s.aaaaRecords[domain]; ok { - for _, ipAddr := range ips { - // Only remove PTR if it points to this domain - if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ipAddr.String()) - } - } - } - delete(s.aRecords, domain) - delete(s.aaaaRecords, domain) + rs.AAAA = removeIP(rs.AAAA, ip) + } + if len(rs.A) == 0 && len(rs.AAAA) == 0 { + delete(s.wildcards, domain) } return } + // Exact domain: find trie node + path := domainToPath(domain) + node := s.root + for _, label := range path { + node = node.children[label] + if node == nil { + return + } + } + if node.data == nil { + return + } + + if ip == nil { + for _, ipAddr := range node.data.A { + if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ipAddr.String()) + } + } + for _, ipAddr := range node.data.AAAA { + if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ipAddr.String()) + } + } + node.data = nil + return + } + if ip.To4() != nil { - // Remove specific IPv4 address - if isWildcard { - if ips, ok := s.aWildcards[domain]; ok { - s.aWildcards[domain] = removeIP(ips, ip) - if len(s.aWildcards[domain]) == 0 { - delete(s.aWildcards, domain) - } - } - } else { - if ips, ok := s.aRecords[domain]; ok { - s.aRecords[domain] = removeIP(ips, ip) - if len(s.aRecords[domain]) == 0 { - delete(s.aRecords, domain) - } - // Automatically remove PTR record if it points to this domain - if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ip.String()) - } - } + node.data.A = removeIP(node.data.A, ip) + if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ip.String()) } - } else if ip.To16() != nil { - // Remove specific IPv6 address - if isWildcard { - if ips, ok := s.aaaaWildcards[domain]; ok { - s.aaaaWildcards[domain] = removeIP(ips, ip) - if len(s.aaaaWildcards[domain]) == 0 { - delete(s.aaaaWildcards, domain) - } - } - } else { - if ips, ok := s.aaaaRecords[domain]; ok { - s.aaaaRecords[domain] = removeIP(ips, ip) - if len(s.aaaaRecords[domain]) == 0 { - delete(s.aaaaRecords, domain) - } - // Automatically remove PTR record if it points to this domain - if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ip.String()) - } - } + } else { + node.data.AAAA = removeIP(node.data.AAAA, ip) + if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ip.String()) } } + if len(node.data.A) == 0 && len(node.data.AAAA) == 0 { + node.data = nil + } } // RemovePTRRecord removes a PTR record for an IP address @@ -206,60 +224,54 @@ func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) { } // GetRecords returns all IP addresses for a domain and record type -// First checks for exact matches, then checks wildcard patterns +// First checks for exact match in the trie, then wildcard patterns func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP { s.mu.RLock() defer s.mu.RUnlock() - // Normalize domain to lowercase FQDN domain = strings.ToLower(dns.Fqdn(domain)) + path := domainToPath(domain) - var records []net.IP - switch recordType { - case RecordTypeA: - // Check exact match first - if ips, ok := s.aRecords[domain]; ok { - // Return a copy to prevent external modifications - records = make([]net.IP, len(ips)) - copy(records, ips) - return records + // Exact match: walk trie + node := s.root + for _, label := range path { + node = node.children[label] + if node == nil { + break } - // Check wildcard patterns - for pattern, ips := range s.aWildcards { - if matchWildcard(pattern, domain) { - records = append(records, ips...) - } + } + if node != nil && node.data != nil { + var ips []net.IP + if recordType == RecordTypeA { + ips = node.data.A + } else { + ips = node.data.AAAA } - if len(records) > 0 { - // Return a copy - result := make([]net.IP, len(records)) - copy(result, records) - return result - } - - case RecordTypeAAAA: - // Check exact match first - if ips, ok := s.aaaaRecords[domain]; ok { - // Return a copy to prevent external modifications - records = make([]net.IP, len(ips)) - copy(records, ips) - return records - } - // Check wildcard patterns - for pattern, ips := range s.aaaaWildcards { - if matchWildcard(pattern, domain) { - records = append(records, ips...) - } - } - if len(records) > 0 { - // Return a copy - result := make([]net.IP, len(records)) - copy(result, records) - return result + if len(ips) > 0 { + out := make([]net.IP, len(ips)) + copy(out, ips) + return out } } - return records + // Wildcard match + var records []net.IP + for pattern, rs := range s.wildcards { + if !matchWildcard(pattern, domain) { + continue + } + if recordType == RecordTypeA { + records = append(records, rs.A...) + } else { + records = append(records, rs.AAAA...) + } + } + if len(records) == 0 { + return nil + } + out := make([]net.IP, len(records)) + copy(out, records) + return out } // GetPTRRecord returns the domain name for a PTR record query @@ -283,39 +295,41 @@ func (s *DNSRecordStore) GetPTRRecord(domain string) (string, bool) { } // HasRecord checks if a domain has any records of the specified type -// Checks both exact matches and wildcard patterns +// Checks both exact matches (trie) and wildcard patterns func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool { s.mu.RLock() defer s.mu.RUnlock() - // Normalize domain to lowercase FQDN domain = strings.ToLower(dns.Fqdn(domain)) + path := domainToPath(domain) - switch recordType { - case RecordTypeA: - // Check exact match - if _, ok := s.aRecords[domain]; ok { + node := s.root + for _, label := range path { + node = node.children[label] + if node == nil { + break + } + } + if node != nil && node.data != nil { + if recordType == RecordTypeA && len(node.data.A) > 0 { return true } - // Check wildcard patterns - for pattern := range s.aWildcards { - if matchWildcard(pattern, domain) { - return true - } - } - case RecordTypeAAAA: - // Check exact match - if _, ok := s.aaaaRecords[domain]; ok { + if recordType == RecordTypeAAAA && len(node.data.AAAA) > 0 { return true } - // Check wildcard patterns - for pattern := range s.aaaaWildcards { - if matchWildcard(pattern, domain) { - return true - } - } } + for pattern, rs := range s.wildcards { + if !matchWildcard(pattern, domain) { + continue + } + if recordType == RecordTypeA && len(rs.A) > 0 { + return true + } + if recordType == RecordTypeAAAA && len(rs.AAAA) > 0 { + return true + } + } return false } @@ -339,10 +353,8 @@ func (s *DNSRecordStore) Clear() { s.mu.Lock() defer s.mu.Unlock() - s.aRecords = make(map[string][]net.IP) - s.aaaaRecords = make(map[string][]net.IP) - s.aWildcards = make(map[string][]net.IP) - s.aaaaWildcards = make(map[string][]net.IP) + s.root = &domainTrieNode{children: make(map[string]*domainTrieNode)} + s.wildcards = make(map[string]*recordSet) s.ptrRecords = make(map[string]string) } @@ -494,4 +506,4 @@ func IPToReverseDNS(ip net.IP) string { } return "" -} \ No newline at end of file +} From 9ae49e36d5c341266d6eff74d948af99faabbd95 Mon Sep 17 00:00:00 2001 From: Laurence Date: Sat, 28 Feb 2026 10:03:09 +0000 Subject: [PATCH 07/15] refactor(dns): simplify DNSRecordStore from trie to map Replace trie-based domain lookup with simple map for O(1) lookups. Add exists boolean to GetRecords for proper NODATA vs NXDOMAIN responses. --- dns/dns_proxy.go | 14 +-- dns/dns_records.go | 199 +++++++++++++++------------------------- dns/dns_records_test.go | 60 ++++++++---- 3 files changed, 125 insertions(+), 148 deletions(-) diff --git a/dns/dns_proxy.go b/dns/dns_proxy.go index 986e847..27770e4 100644 --- a/dns/dns_proxy.go +++ b/dns/dns_proxy.go @@ -447,19 +447,20 @@ func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns return nil } - ips := p.recordStore.GetRecords(question.Name, recordType) - if len(ips) == 0 { + ips, exists := p.recordStore.GetRecords(question.Name, recordType) + if !exists { + // Domain not found in local records, forward to upstream return nil } logger.Debug("Found %d local record(s) for %s", len(ips), question.Name) - // Create response message + // Create response message (NODATA if no records, otherwise with answers) response := new(dns.Msg) response.SetReply(query) response.Authoritative = true - // Add answer records + // Add answer records (loop is a no-op if ips is empty) for _, ip := range ips { var rr dns.RR if question.Qtype == dns.TypeA { @@ -730,8 +731,9 @@ func (p *DNSProxy) RemoveDNSRecord(domain string, ip net.IP) { p.recordStore.RemoveRecord(domain, ip) } -// GetDNSRecords returns all IP addresses for a domain and record type -func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) []net.IP { +// GetDNSRecords returns all IP addresses for a domain and record type. +// The second return value indicates whether the domain exists. +func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) ([]net.IP, bool) { return p.recordStore.GetRecords(domain, recordType) } diff --git a/dns/dns_records.go b/dns/dns_records.go index 5c62043..10bb7f3 100644 --- a/dns/dns_records.go +++ b/dns/dns_records.go @@ -24,41 +24,19 @@ type recordSet struct { AAAA []net.IP } -// domainTrieNode is a node in the trie for exact domain lookups (no wildcards in path) -type domainTrieNode struct { - children map[string]*domainTrieNode - data *recordSet -} - // DNSRecordStore manages local DNS records for A, AAAA, and PTR queries. -// Exact domains are stored in a trie for O(label count) lookup; wildcard patterns -// are in a separate map. Each domain/pattern has a single recordSet (A + AAAA). +// Exact domains are stored in a map; wildcard patterns are in a separate map. type DNSRecordStore struct { mu sync.RWMutex - root *domainTrieNode // trie root for exact lookups + exact map[string]*recordSet // normalized FQDN -> A/AAAA records wildcards map[string]*recordSet // wildcard pattern -> A/AAAA records ptrRecords map[string]string // IP address string -> domain name } -// domainToPath converts a FQDN to a trie path (reversed labels, e.g. "host.internal." -> ["internal", "host"]) -func domainToPath(domain string) []string { - domain = strings.ToLower(dns.Fqdn(domain)) - domain = strings.TrimSuffix(domain, ".") - if domain == "" { - return nil - } - labels := strings.Split(domain, ".") - path := make([]string, 0, len(labels)) - for i := len(labels) - 1; i >= 0; i-- { - path = append(path, labels[i]) - } - return path -} - // NewDNSRecordStore creates a new DNS record store func NewDNSRecordStore() *DNSRecordStore { return &DNSRecordStore{ - root: &domainTrieNode{children: make(map[string]*domainTrieNode)}, + exact: make(map[string]*recordSet), wildcards: make(map[string]*recordSet), ptrRecords: make(map[string]string), } @@ -84,36 +62,26 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error { return &net.ParseError{Type: "IP address", Text: ip.String()} } + // Choose the appropriate map based on whether this is a wildcard + m := s.exact if isWildcard { - if s.wildcards[domain] == nil { - s.wildcards[domain] = &recordSet{} - } - rs := s.wildcards[domain] - if isV4 { - rs.A = append(rs.A, ip) - } else { - rs.AAAA = append(rs.AAAA, ip) - } - return nil + m = s.wildcards } - path := domainToPath(domain) - node := s.root - for _, label := range path { - if node.children[label] == nil { - node.children[label] = &domainTrieNode{children: make(map[string]*domainTrieNode)} - } - node = node.children[label] - } - if node.data == nil { - node.data = &recordSet{} + if m[domain] == nil { + m[domain] = &recordSet{} } + rs := m[domain] if isV4 { - node.data.A = append(node.data.A, ip) + rs.A = append(rs.A, ip) } else { - node.data.AAAA = append(node.data.AAAA, ip) + rs.AAAA = append(rs.AAAA, ip) + } + + // Add PTR record for non-wildcard domains + if !isWildcard { + s.ptrRecords[ip.String()] = domain } - s.ptrRecords[ip.String()] = domain return nil } @@ -151,67 +119,55 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) { domain = strings.ToLower(dns.Fqdn(domain)) isWildcard := strings.ContainsAny(domain, "*?") + // Choose the appropriate map + m := s.exact if isWildcard { - if ip == nil { - delete(s.wildcards, domain) - return - } - rs := s.wildcards[domain] - if rs == nil { - return - } - if ip.To4() != nil { - rs.A = removeIP(rs.A, ip) - } else { - rs.AAAA = removeIP(rs.AAAA, ip) - } - if len(rs.A) == 0 && len(rs.AAAA) == 0 { - delete(s.wildcards, domain) - } - return + m = s.wildcards } - // Exact domain: find trie node - path := domainToPath(domain) - node := s.root - for _, label := range path { - node = node.children[label] - if node == nil { - return - } - } - if node.data == nil { + rs := m[domain] + if rs == nil { return } if ip == nil { - for _, ipAddr := range node.data.A { - if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ipAddr.String()) + // Remove all records for this domain + if !isWildcard { + for _, ipAddr := range rs.A { + if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ipAddr.String()) + } + } + for _, ipAddr := range rs.AAAA { + if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ipAddr.String()) + } } } - for _, ipAddr := range node.data.AAAA { - if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ipAddr.String()) - } - } - node.data = nil + delete(m, domain) return } + // Remove specific IP if ip.To4() != nil { - node.data.A = removeIP(node.data.A, ip) - if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ip.String()) + rs.A = removeIP(rs.A, ip) + if !isWildcard { + if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ip.String()) + } } } else { - node.data.AAAA = removeIP(node.data.AAAA, ip) - if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ip.String()) + rs.AAAA = removeIP(rs.AAAA, ip) + if !isWildcard { + if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ip.String()) + } } } - if len(node.data.A) == 0 && len(node.data.AAAA) == 0 { - node.data = nil + + // Clean up empty record sets + if len(rs.A) == 0 && len(rs.AAAA) == 0 { + delete(m, domain) } } @@ -223,55 +179,56 @@ func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) { delete(s.ptrRecords, ip.String()) } -// GetRecords returns all IP addresses for a domain and record type -// First checks for exact match in the trie, then wildcard patterns -func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP { +// GetRecords returns all IP addresses for a domain and record type. +// The second return value indicates whether the domain exists at all +// (true = domain exists, use NODATA if no records; false = NXDOMAIN). +func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) ([]net.IP, bool) { s.mu.RLock() defer s.mu.RUnlock() domain = strings.ToLower(dns.Fqdn(domain)) - path := domainToPath(domain) - // Exact match: walk trie - node := s.root - for _, label := range path { - node = node.children[label] - if node == nil { - break - } - } - if node != nil && node.data != nil { + // Check exact match first + if rs, exists := s.exact[domain]; exists { var ips []net.IP if recordType == RecordTypeA { - ips = node.data.A + ips = rs.A } else { - ips = node.data.AAAA + ips = rs.AAAA } if len(ips) > 0 { out := make([]net.IP, len(ips)) copy(out, ips) - return out + return out, true } + // Domain exists but no records of this type + return nil, true } - // Wildcard match + // Check wildcard matches var records []net.IP + matched := false for pattern, rs := range s.wildcards { if !matchWildcard(pattern, domain) { continue } + matched = true if recordType == RecordTypeA { records = append(records, rs.A...) } else { records = append(records, rs.AAAA...) } } + + if !matched { + return nil, false + } if len(records) == 0 { - return nil + return nil, true } out := make([]net.IP, len(records)) copy(out, records) - return out + return out, true } // GetPTRRecord returns the domain name for a PTR record query @@ -295,30 +252,24 @@ func (s *DNSRecordStore) GetPTRRecord(domain string) (string, bool) { } // HasRecord checks if a domain has any records of the specified type -// Checks both exact matches (trie) and wildcard patterns +// Checks both exact matches and wildcard patterns func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool { s.mu.RLock() defer s.mu.RUnlock() domain = strings.ToLower(dns.Fqdn(domain)) - path := domainToPath(domain) - node := s.root - for _, label := range path { - node = node.children[label] - if node == nil { - break - } - } - if node != nil && node.data != nil { - if recordType == RecordTypeA && len(node.data.A) > 0 { + // Check exact match + if rs, exists := s.exact[domain]; exists { + if recordType == RecordTypeA && len(rs.A) > 0 { return true } - if recordType == RecordTypeAAAA && len(node.data.AAAA) > 0 { + if recordType == RecordTypeAAAA && len(rs.AAAA) > 0 { return true } } + // Check wildcard matches for pattern, rs := range s.wildcards { if !matchWildcard(pattern, domain) { continue @@ -353,7 +304,7 @@ func (s *DNSRecordStore) Clear() { s.mu.Lock() defer s.mu.Unlock() - s.root = &domainTrieNode{children: make(map[string]*domainTrieNode)} + s.exact = make(map[string]*recordSet) s.wildcards = make(map[string]*recordSet) s.ptrRecords = make(map[string]string) } diff --git a/dns/dns_records_test.go b/dns/dns_records_test.go index eae9372..963dcc1 100644 --- a/dns/dns_records_test.go +++ b/dns/dns_records_test.go @@ -183,25 +183,34 @@ func TestDNSRecordStoreWildcard(t *testing.T) { } // Test exact match takes precedence - ips := store.GetRecords("exact.autoco.internal.", RecordTypeA) + ips, exists := store.GetRecords("exact.autoco.internal.", RecordTypeA) + if !exists { + t.Error("Expected domain to exist") + } if len(ips) != 1 { t.Errorf("Expected 1 IP for exact match, got %d", len(ips)) } - if !ips[0].Equal(exactIP) { + if len(ips) > 0 && !ips[0].Equal(exactIP) { t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0]) } // Test wildcard match - ips = store.GetRecords("host.autoco.internal.", RecordTypeA) + ips, exists = store.GetRecords("host.autoco.internal.", RecordTypeA) + if !exists { + t.Error("Expected wildcard match to exist") + } if len(ips) != 1 { t.Errorf("Expected 1 IP for wildcard match, got %d", len(ips)) } - if !ips[0].Equal(wildcardIP) { + if len(ips) > 0 && !ips[0].Equal(wildcardIP) { t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0]) } // Test non-match (base domain) - ips = store.GetRecords("autoco.internal.", RecordTypeA) + ips, exists = store.GetRecords("autoco.internal.", RecordTypeA) + if exists { + t.Error("Expected base domain to not exist") + } if len(ips) != 0 { t.Errorf("Expected 0 IPs for base domain, got %d", len(ips)) } @@ -218,7 +227,10 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) { } // Test matching domain - ips := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA) + ips, exists := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA) + if !exists { + t.Error("Expected complex wildcard match to exist") + } if len(ips) != 1 { t.Errorf("Expected 1 IP for complex wildcard match, got %d", len(ips)) } @@ -227,13 +239,19 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) { } // Test non-matching domain (missing prefix) - ips = store.GetRecords("host-01.autoco.internal.", RecordTypeA) + ips, exists = store.GetRecords("host-01.autoco.internal.", RecordTypeA) + if exists { + t.Error("Expected domain without prefix to not exist") + } if len(ips) != 0 { t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips)) } // Test non-matching domain (wrong ? position) - ips = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA) + ips, exists = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA) + if exists { + t.Error("Expected domain with wrong ? match to not exist") + } if len(ips) != 0 { t.Errorf("Expected 0 IPs for domain with wrong ? match, got %d", len(ips)) } @@ -250,7 +268,10 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) { } // Verify it exists - ips := store.GetRecords("host.autoco.internal.", RecordTypeA) + ips, exists := store.GetRecords("host.autoco.internal.", RecordTypeA) + if !exists { + t.Error("Expected domain to exist before removal") + } if len(ips) != 1 { t.Errorf("Expected 1 IP before removal, got %d", len(ips)) } @@ -259,7 +280,10 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) { store.RemoveRecord("*.autoco.internal", nil) // Verify it's gone - ips = store.GetRecords("host.autoco.internal.", RecordTypeA) + ips, exists = store.GetRecords("host.autoco.internal.", RecordTypeA) + if exists { + t.Error("Expected domain to not exist after removal") + } if len(ips) != 0 { t.Errorf("Expected 0 IPs after removal, got %d", len(ips)) } @@ -290,19 +314,19 @@ func TestDNSRecordStoreMultipleWildcards(t *testing.T) { } // Test domain matching only the prod pattern and the broad pattern - ips := store.GetRecords("host.prod.autoco.internal.", RecordTypeA) + ips, _ := store.GetRecords("host.prod.autoco.internal.", RecordTypeA) if len(ips) != 2 { t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips)) } // Test domain matching only the dev pattern and the broad pattern - ips = store.GetRecords("service.dev.autoco.internal.", RecordTypeA) + ips, _ = store.GetRecords("service.dev.autoco.internal.", RecordTypeA) if len(ips) != 2 { t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips)) } // Test domain matching only the broad pattern - ips = store.GetRecords("host.test.autoco.internal.", RecordTypeA) + ips, _ = store.GetRecords("host.test.autoco.internal.", RecordTypeA) if len(ips) != 1 { t.Errorf("Expected 1 IP (broad only), got %d", len(ips)) } @@ -319,7 +343,7 @@ func TestDNSRecordStoreIPv6Wildcard(t *testing.T) { } // Test wildcard match for IPv6 - ips := store.GetRecords("host.autoco.internal.", RecordTypeAAAA) + ips, _ := store.GetRecords("host.autoco.internal.", RecordTypeAAAA) if len(ips) != 1 { t.Errorf("Expected 1 IPv6 for wildcard match, got %d", len(ips)) } @@ -368,7 +392,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) { } for _, domain := range testCases { - ips := store.GetRecords(domain, RecordTypeA) + ips, _ := store.GetRecords(domain, RecordTypeA) if len(ips) != 1 { t.Errorf("Expected 1 IP for domain %q, got %d", domain, len(ips)) } @@ -392,7 +416,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) { } for _, domain := range wildcardTestCases { - ips := store.GetRecords(domain, RecordTypeA) + ips, _ := store.GetRecords(domain, RecordTypeA) if len(ips) != 1 { t.Errorf("Expected 1 IP for wildcard domain %q, got %d", domain, len(ips)) } @@ -403,7 +427,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) { // Test removal with different case store.RemoveRecord("MYHOST.AUTOCO.INTERNAL", nil) - ips := store.GetRecords("myhost.autoco.internal.", RecordTypeA) + ips, _ := store.GetRecords("myhost.autoco.internal.", RecordTypeA) if len(ips) != 0 { t.Errorf("Expected 0 IPs after removal, got %d", len(ips)) } @@ -752,7 +776,7 @@ func TestAutomaticPTRRecordOnRemove(t *testing.T) { } // Verify A record is also gone - ips := store.GetRecords(domain, RecordTypeA) + ips, _ := store.GetRecords(domain, RecordTypeA) if len(ips) != 0 { t.Errorf("Expected A record to be removed, got %d records", len(ips)) } From ae88766d85926e3643fc7ec0c6452b0270da8167 Mon Sep 17 00:00:00 2001 From: Laurence Date: Sat, 28 Feb 2026 10:22:37 +0000 Subject: [PATCH 08/15] test(dns): add dns test cases for nodata --- dns/dns_proxy_test.go | 178 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 178 insertions(+) create mode 100644 dns/dns_proxy_test.go diff --git a/dns/dns_proxy_test.go b/dns/dns_proxy_test.go new file mode 100644 index 0000000..4a1d9f9 --- /dev/null +++ b/dns/dns_proxy_test.go @@ -0,0 +1,178 @@ +package dns + +import ( + "net" + "testing" + + "github.com/miekg/dns" +) + +func TestCheckLocalRecordsNODATAForAAAA(t *testing.T) { + proxy := &DNSProxy{ + recordStore: NewDNSRecordStore(), + } + + // Add an A record for a domain (no AAAA record) + ip := net.ParseIP("10.0.0.1") + err := proxy.recordStore.AddRecord("myservice.internal", ip) + if err != nil { + t.Fatalf("Failed to add A record: %v", err) + } + + // Query AAAA for domain with only A record - should return NODATA + query := new(dns.Msg) + query.SetQuestion("myservice.internal.", dns.TypeAAAA) + response := proxy.checkLocalRecords(query, query.Question[0]) + + if response == nil { + t.Fatal("Expected NODATA response, got nil (would forward to upstream)") + } + if response.Rcode != dns.RcodeSuccess { + t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode) + } + if len(response.Answer) != 0 { + t.Errorf("Expected empty answer section for NODATA, got %d answers", len(response.Answer)) + } + if !response.Authoritative { + t.Error("Expected response to be authoritative") + } + + // Query A for same domain - should return the record + query = new(dns.Msg) + query.SetQuestion("myservice.internal.", dns.TypeA) + response = proxy.checkLocalRecords(query, query.Question[0]) + + if response == nil { + t.Fatal("Expected response with A record, got nil") + } + if len(response.Answer) != 1 { + t.Fatalf("Expected 1 answer, got %d", len(response.Answer)) + } + aRecord, ok := response.Answer[0].(*dns.A) + if !ok { + t.Fatal("Expected A record in answer") + } + if !aRecord.A.Equal(ip.To4()) { + t.Errorf("Expected IP %v, got %v", ip.To4(), aRecord.A) + } +} + +func TestCheckLocalRecordsNODATAForA(t *testing.T) { + proxy := &DNSProxy{ + recordStore: NewDNSRecordStore(), + } + + // Add an AAAA record for a domain (no A record) + ip := net.ParseIP("2001:db8::1") + err := proxy.recordStore.AddRecord("ipv6only.internal", ip) + if err != nil { + t.Fatalf("Failed to add AAAA record: %v", err) + } + + // Query A for domain with only AAAA record - should return NODATA + query := new(dns.Msg) + query.SetQuestion("ipv6only.internal.", dns.TypeA) + response := proxy.checkLocalRecords(query, query.Question[0]) + + if response == nil { + t.Fatal("Expected NODATA response, got nil") + } + if response.Rcode != dns.RcodeSuccess { + t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode) + } + if len(response.Answer) != 0 { + t.Errorf("Expected empty answer section, got %d answers", len(response.Answer)) + } + if !response.Authoritative { + t.Error("Expected response to be authoritative") + } + + // Query AAAA for same domain - should return the record + query = new(dns.Msg) + query.SetQuestion("ipv6only.internal.", dns.TypeAAAA) + response = proxy.checkLocalRecords(query, query.Question[0]) + + if response == nil { + t.Fatal("Expected response with AAAA record, got nil") + } + if len(response.Answer) != 1 { + t.Fatalf("Expected 1 answer, got %d", len(response.Answer)) + } + aaaaRecord, ok := response.Answer[0].(*dns.AAAA) + if !ok { + t.Fatal("Expected AAAA record in answer") + } + if !aaaaRecord.AAAA.Equal(ip) { + t.Errorf("Expected IP %v, got %v", ip, aaaaRecord.AAAA) + } +} + +func TestCheckLocalRecordsNonExistentDomain(t *testing.T) { + proxy := &DNSProxy{ + recordStore: NewDNSRecordStore(), + } + + // Add a record so the store isn't empty + err := proxy.recordStore.AddRecord("exists.internal", net.ParseIP("10.0.0.1")) + if err != nil { + t.Fatalf("Failed to add record: %v", err) + } + + // Query A for non-existent domain - should return nil (forward to upstream) + query := new(dns.Msg) + query.SetQuestion("unknown.internal.", dns.TypeA) + response := proxy.checkLocalRecords(query, query.Question[0]) + + if response != nil { + t.Error("Expected nil for non-existent domain, got response") + } + + // Query AAAA for non-existent domain - should also return nil + query = new(dns.Msg) + query.SetQuestion("unknown.internal.", dns.TypeAAAA) + response = proxy.checkLocalRecords(query, query.Question[0]) + + if response != nil { + t.Error("Expected nil for non-existent domain, got response") + } +} + +func TestCheckLocalRecordsNODATAWildcard(t *testing.T) { + proxy := &DNSProxy{ + recordStore: NewDNSRecordStore(), + } + + // Add a wildcard A record (no AAAA) + ip := net.ParseIP("10.0.0.1") + err := proxy.recordStore.AddRecord("*.wildcard.internal", ip) + if err != nil { + t.Fatalf("Failed to add wildcard A record: %v", err) + } + + // Query AAAA for wildcard-matched domain - should return NODATA + query := new(dns.Msg) + query.SetQuestion("host.wildcard.internal.", dns.TypeAAAA) + response := proxy.checkLocalRecords(query, query.Question[0]) + + if response == nil { + t.Fatal("Expected NODATA response for wildcard match, got nil") + } + if response.Rcode != dns.RcodeSuccess { + t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode) + } + if len(response.Answer) != 0 { + t.Errorf("Expected empty answer section, got %d answers", len(response.Answer)) + } + + // Query A for wildcard-matched domain - should return the record + query = new(dns.Msg) + query.SetQuestion("host.wildcard.internal.", dns.TypeA) + response = proxy.checkLocalRecords(query, query.Question[0]) + + if response == nil { + t.Fatal("Expected response with A record, got nil") + } + if len(response.Answer) != 1 { + t.Fatalf("Expected 1 answer, got %d", len(response.Answer)) + } +} From e2690bcc03514ecca27a2a13ca637922f792a378 Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 6 Mar 2026 16:19:00 -0800 Subject: [PATCH 09/15] Store site id --- dns/dns_proxy.go | 4 ++-- dns/dns_proxy_test.go | 8 ++++---- dns/dns_records.go | 34 ++++++++++++++++++++++++++++++---- dns/dns_records_test.go | 40 ++++++++++++++++++++-------------------- peers/manager.go | 6 +++--- 5 files changed, 59 insertions(+), 33 deletions(-) diff --git a/dns/dns_proxy.go b/dns/dns_proxy.go index 27770e4..7b7858c 100644 --- a/dns/dns_proxy.go +++ b/dns/dns_proxy.go @@ -721,8 +721,8 @@ func (p *DNSProxy) runPacketSender() { // AddDNSRecord adds a DNS record to the local store // domain should be a domain name (e.g., "example.com" or "example.com.") // ip should be a valid IPv4 or IPv6 address -func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP) error { - return p.recordStore.AddRecord(domain, ip) +func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP, siteId int) error { + return p.recordStore.AddRecord(domain, ip, siteId) } // RemoveDNSRecord removes a DNS record from the local store diff --git a/dns/dns_proxy_test.go b/dns/dns_proxy_test.go index 4a1d9f9..9eecad7 100644 --- a/dns/dns_proxy_test.go +++ b/dns/dns_proxy_test.go @@ -14,7 +14,7 @@ func TestCheckLocalRecordsNODATAForAAAA(t *testing.T) { // Add an A record for a domain (no AAAA record) ip := net.ParseIP("10.0.0.1") - err := proxy.recordStore.AddRecord("myservice.internal", ip) + err := proxy.recordStore.AddRecord("myservice.internal", ip, 0) if err != nil { t.Fatalf("Failed to add A record: %v", err) } @@ -64,7 +64,7 @@ func TestCheckLocalRecordsNODATAForA(t *testing.T) { // Add an AAAA record for a domain (no A record) ip := net.ParseIP("2001:db8::1") - err := proxy.recordStore.AddRecord("ipv6only.internal", ip) + err := proxy.recordStore.AddRecord("ipv6only.internal", ip, 0) if err != nil { t.Fatalf("Failed to add AAAA record: %v", err) } @@ -113,7 +113,7 @@ func TestCheckLocalRecordsNonExistentDomain(t *testing.T) { } // Add a record so the store isn't empty - err := proxy.recordStore.AddRecord("exists.internal", net.ParseIP("10.0.0.1")) + err := proxy.recordStore.AddRecord("exists.internal", net.ParseIP("10.0.0.1"), 0) if err != nil { t.Fatalf("Failed to add record: %v", err) } @@ -144,7 +144,7 @@ func TestCheckLocalRecordsNODATAWildcard(t *testing.T) { // Add a wildcard A record (no AAAA) ip := net.ParseIP("10.0.0.1") - err := proxy.recordStore.AddRecord("*.wildcard.internal", ip) + err := proxy.recordStore.AddRecord("*.wildcard.internal", ip, 0) if err != nil { t.Fatalf("Failed to add wildcard A record: %v", err) } diff --git a/dns/dns_records.go b/dns/dns_records.go index 10bb7f3..c52c08e 100644 --- a/dns/dns_records.go +++ b/dns/dns_records.go @@ -20,8 +20,9 @@ const ( // recordSet holds A and AAAA records for a single domain or wildcard pattern type recordSet struct { - A []net.IP - AAAA []net.IP + A []net.IP + AAAA []net.IP + SiteId int } // DNSRecordStore manages local DNS records for A, AAAA, and PTR queries. @@ -46,8 +47,9 @@ func NewDNSRecordStore() *DNSRecordStore { // domain should be in FQDN format (e.g., "example.com.") // domain can contain wildcards: * (0+ chars) and ? (exactly 1 char) // ip should be a valid IPv4 or IPv6 address +// siteId is the site that owns this alias/domain // Automatically adds a corresponding PTR record for non-wildcard domains -func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error { +func (s *DNSRecordStore) AddRecord(domain string, ip net.IP, siteId int) error { s.mu.Lock() defer s.mu.Unlock() @@ -69,7 +71,7 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error { } if m[domain] == nil { - m[domain] = &recordSet{} + m[domain] = &recordSet{SiteId: siteId} } rs := m[domain] if isV4 { @@ -179,6 +181,30 @@ func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) { delete(s.ptrRecords, ip.String()) } +// GetSiteIdForDomain returns the siteId associated with the given domain. +// It checks exact matches first, then wildcard patterns. +// The second return value is false if the domain is not found in local records. +func (s *DNSRecordStore) GetSiteIdForDomain(domain string) (int, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + + domain = strings.ToLower(dns.Fqdn(domain)) + + // Check exact match first + if rs, exists := s.exact[domain]; exists { + return rs.SiteId, true + } + + // Check wildcard matches + for pattern, rs := range s.wildcards { + if matchWildcard(pattern, domain) { + return rs.SiteId, true + } + } + + return 0, false +} + // GetRecords returns all IP addresses for a domain and record type. // The second return value indicates whether the domain exists at all // (true = domain exists, use NODATA if no records; false = NXDOMAIN). diff --git a/dns/dns_records_test.go b/dns/dns_records_test.go index 963dcc1..0b4481d 100644 --- a/dns/dns_records_test.go +++ b/dns/dns_records_test.go @@ -170,14 +170,14 @@ func TestDNSRecordStoreWildcard(t *testing.T) { // Add wildcard records wildcardIP := net.ParseIP("10.0.0.1") - err := store.AddRecord("*.autoco.internal", wildcardIP) + err := store.AddRecord("*.autoco.internal", wildcardIP, 0) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } // Add exact record exactIP := net.ParseIP("10.0.0.2") - err = store.AddRecord("exact.autoco.internal", exactIP) + err = store.AddRecord("exact.autoco.internal", exactIP, 0) if err != nil { t.Fatalf("Failed to add exact record: %v", err) } @@ -221,7 +221,7 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) { // Add complex wildcard pattern ip1 := net.ParseIP("10.0.0.1") - err := store.AddRecord("*.host-0?.autoco.internal", ip1) + err := store.AddRecord("*.host-0?.autoco.internal", ip1, 0) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } @@ -262,7 +262,7 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) { // Add wildcard record ip := net.ParseIP("10.0.0.1") - err := store.AddRecord("*.autoco.internal", ip) + err := store.AddRecord("*.autoco.internal", ip, 0) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } @@ -297,18 +297,18 @@ func TestDNSRecordStoreMultipleWildcards(t *testing.T) { ip2 := net.ParseIP("10.0.0.2") ip3 := net.ParseIP("10.0.0.3") - err := store.AddRecord("*.prod.autoco.internal", ip1) + err := store.AddRecord("*.prod.autoco.internal", ip1, 0) if err != nil { t.Fatalf("Failed to add first wildcard: %v", err) } - err = store.AddRecord("*.dev.autoco.internal", ip2) + err = store.AddRecord("*.dev.autoco.internal", ip2, 0) if err != nil { t.Fatalf("Failed to add second wildcard: %v", err) } // Add a broader wildcard that matches both - err = store.AddRecord("*.autoco.internal", ip3) + err = store.AddRecord("*.autoco.internal", ip3, 0) if err != nil { t.Fatalf("Failed to add third wildcard: %v", err) } @@ -337,7 +337,7 @@ func TestDNSRecordStoreIPv6Wildcard(t *testing.T) { // Add IPv6 wildcard record ip := net.ParseIP("2001:db8::1") - err := store.AddRecord("*.autoco.internal", ip) + err := store.AddRecord("*.autoco.internal", ip, 0) if err != nil { t.Fatalf("Failed to add IPv6 wildcard record: %v", err) } @@ -357,7 +357,7 @@ func TestHasRecordWildcard(t *testing.T) { // Add wildcard record ip := net.ParseIP("10.0.0.1") - err := store.AddRecord("*.autoco.internal", ip) + err := store.AddRecord("*.autoco.internal", ip, 0) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } @@ -378,7 +378,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) { // Add record with mixed case ip := net.ParseIP("10.0.0.1") - err := store.AddRecord("MyHost.AutoCo.Internal", ip) + err := store.AddRecord("MyHost.AutoCo.Internal", ip, 0) if err != nil { t.Fatalf("Failed to add mixed case record: %v", err) } @@ -403,7 +403,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) { // Test wildcard with mixed case wildcardIP := net.ParseIP("10.0.0.2") - err = store.AddRecord("*.Example.Com", wildcardIP) + err = store.AddRecord("*.Example.Com", wildcardIP, 0) if err != nil { t.Fatalf("Failed to add mixed case wildcard: %v", err) } @@ -689,7 +689,7 @@ func TestClearPTRRecords(t *testing.T) { store.AddPTRRecord(ip2, "host2.example.com.") // Add some A records too - store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1")) + store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1"), 0) // Verify PTR records exist if !store.HasPTRRecord("1.1.168.192.in-addr.arpa.") { @@ -719,7 +719,7 @@ func TestAutomaticPTRRecordOnAdd(t *testing.T) { // Add an A record - should automatically add PTR record domain := "host.example.com." ip := net.ParseIP("192.168.1.100") - err := store.AddRecord(domain, ip) + err := store.AddRecord(domain, ip, 0) if err != nil { t.Fatalf("Failed to add A record: %v", err) } @@ -737,7 +737,7 @@ func TestAutomaticPTRRecordOnAdd(t *testing.T) { // Add AAAA record - should also automatically add PTR record domain6 := "ipv6host.example.com." ip6 := net.ParseIP("2001:db8::1") - err = store.AddRecord(domain6, ip6) + err = store.AddRecord(domain6, ip6, 0) if err != nil { t.Fatalf("Failed to add AAAA record: %v", err) } @@ -759,7 +759,7 @@ func TestAutomaticPTRRecordOnRemove(t *testing.T) { // Add an A record (with automatic PTR) domain := "host.example.com." ip := net.ParseIP("192.168.1.100") - store.AddRecord(domain, ip) + store.AddRecord(domain, ip, 0) // Verify PTR exists reverseDomain := "100.1.168.192.in-addr.arpa." @@ -789,8 +789,8 @@ func TestAutomaticPTRRecordOnRemoveAll(t *testing.T) { domain := "host.example.com." ip1 := net.ParseIP("192.168.1.100") ip2 := net.ParseIP("192.168.1.101") - store.AddRecord(domain, ip1) - store.AddRecord(domain, ip2) + store.AddRecord(domain, ip1, 0) + store.AddRecord(domain, ip2, 0) // Verify both PTR records exist reverseDomain1 := "100.1.168.192.in-addr.arpa." @@ -820,7 +820,7 @@ func TestNoPTRForWildcardRecords(t *testing.T) { // Add wildcard record - should NOT create PTR record domain := "*.example.com." ip := net.ParseIP("192.168.1.100") - err := store.AddRecord(domain, ip) + err := store.AddRecord(domain, ip, 0) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } @@ -844,7 +844,7 @@ func TestPTRRecordOverwrite(t *testing.T) { // Add first domain with IP domain1 := "host1.example.com." ip := net.ParseIP("192.168.1.100") - store.AddRecord(domain1, ip) + store.AddRecord(domain1, ip, 0) // Verify PTR points to first domain reverseDomain := "100.1.168.192.in-addr.arpa." @@ -858,7 +858,7 @@ func TestPTRRecordOverwrite(t *testing.T) { // Add second domain with same IP - should overwrite PTR domain2 := "host2.example.com." - store.AddRecord(domain2, ip) + store.AddRecord(domain2, ip, 0) // Verify PTR now points to second domain (last one added) result, ok = store.GetPTRRecord(reverseDomain) diff --git a/peers/manager.go b/peers/manager.go index 0566775..e9925eb 100644 --- a/peers/manager.go +++ b/peers/manager.go @@ -144,7 +144,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error { if address == nil { continue } - pm.dnsProxy.AddDNSRecord(alias.Alias, address) + pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteConfig.SiteId) } monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0] @@ -433,7 +433,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error { if address == nil { continue } - pm.dnsProxy.AddDNSRecord(alias.Alias, address) + pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteConfig.SiteId) } pm.peerMonitor.UpdateHolepunchEndpoint(siteConfig.SiteId, siteConfig.Endpoint) @@ -713,7 +713,7 @@ func (pm *PeerManager) AddAlias(siteId int, alias Alias) error { address := net.ParseIP(alias.AliasAddress) if address != nil { - pm.dnsProxy.AddDNSRecord(alias.Alias, address) + pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteId) } // Add an allowed IP for the alias From 3f258d3500abcf77cc77554d2e2468905c159c3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Gilerson?= Date: Sun, 8 Mar 2026 01:59:39 +0100 Subject: [PATCH 10/15] Fix crash when peer has nil publicKey in site config Skip sites with empty/nil publicKey instead of passing them to the WireGuard UAPI layer, which expects a valid 64-char hex string. A nil key occurs when a Newt site has never connected. Previously this caused all sites to fail with "hex string does not fit the slice". --- olm/connect.go | 6 ++++++ olm/peer.go | 5 +++++ 2 files changed, 11 insertions(+) diff --git a/olm/connect.go b/olm/connect.go index dc05d1f..afa5c4b 100644 --- a/olm/connect.go +++ b/olm/connect.go @@ -172,6 +172,12 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) { for i := range wgData.Sites { site := wgData.Sites[i] + + if site.PublicKey == "" { + logger.Warn("Skipping site %d (%s): no public key available (site may not be connected)", site.SiteId, site.Name) + continue + } + var siteEndpoint string // here we are going to take the relay endpoint if it exists which means we requested a relay for this peer if site.RelayEndpoint != "" { diff --git a/olm/peer.go b/olm/peer.go index 8007272..9d753b7 100644 --- a/olm/peer.go +++ b/olm/peer.go @@ -37,6 +37,11 @@ func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) { return } + if siteConfig.PublicKey == "" { + logger.Warn("Skipping add-peer for site %d (%s): no public key available (site may not be connected)", siteConfig.SiteId, siteConfig.Name) + return + } + _ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it if err := o.peerManager.AddPeer(siteConfig); err != nil { From 22cd02ae15c317c6684b238d50c817e9e08bbd56 Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 11 Mar 2026 15:56:51 -0700 Subject: [PATCH 11/15] Alias jit handler --- dns/dns_proxy.go | 23 +++++++++++++++++++++++ olm/connect.go | 31 +++++++++++++++++++++++++++++++ olm/olm.go | 11 +++++++---- olm/peer.go | 11 +++++++++++ 4 files changed, 72 insertions(+), 4 deletions(-) diff --git a/dns/dns_proxy.go b/dns/dns_proxy.go index 7b7858c..7a69f53 100644 --- a/dns/dns_proxy.go +++ b/dns/dns_proxy.go @@ -45,6 +45,11 @@ type DNSProxy struct { tunnelActivePorts map[uint16]bool tunnelPortsLock sync.Mutex + // jitHandler is called when a local record is resolved for a site that may not be + // connected yet, giving the caller a chance to initiate a JIT connection. + // It is invoked asynchronously so it never blocks DNS resolution. + jitHandler func(siteId int) + ctx context.Context cancel context.CancelFunc wg sync.WaitGroup @@ -384,6 +389,16 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie response = p.checkLocalRecords(msg, question) } + // If a local A/AAAA record was found, notify the JIT handler so that the owning + // site can be connected on-demand if it is not yet active. + if response != nil && p.jitHandler != nil && + (question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA) { + if siteId, ok := p.recordStore.GetSiteIdForDomain(question.Name); ok && siteId != 0 { + handler := p.jitHandler + go handler(siteId) + } + } + // If no local records, forward to upstream if response == nil { logger.Debug("No local record for %s, forwarding upstream to %v", question.Name, p.upstreamDNS) @@ -718,6 +733,14 @@ func (p *DNSProxy) runPacketSender() { } } +// SetJITHandler registers a callback that is invoked whenever a local DNS record is +// resolved for an A or AAAA query. The siteId identifies which site owns the record. +// The handler is called in its own goroutine so it must be safe to call concurrently. +// Pass nil to disable JIT notifications. +func (p *DNSProxy) SetJITHandler(handler func(siteId int)) { + p.jitHandler = handler +} + // AddDNSRecord adds a DNS record to the local store // domain should be a domain name (e.g., "example.com" or "example.com.") // ip should be a valid IPv4 or IPv6 address diff --git a/olm/connect.go b/olm/connect.go index dc05d1f..1e00ee2 100644 --- a/olm/connect.go +++ b/olm/connect.go @@ -7,6 +7,7 @@ import ( "runtime" "strconv" "strings" + "time" "github.com/fosrl/newt/logger" "github.com/fosrl/newt/network" @@ -196,6 +197,36 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) { logger.Error("Failed to start DNS proxy: %v", err) } + // Register JIT handler: when the DNS proxy resolves a local record, check whether + // the owning site is already connected and, if not, initiate a JIT connection. + o.dnsProxy.SetJITHandler(func(siteId int) { + if o.peerManager == nil || o.websocket == nil { + return + } + + // Site already has an active peer connection - nothing to do. + if _, exists := o.peerManager.GetPeer(siteId); exists { + return + } + + o.peerSendMu.Lock() + defer o.peerSendMu.Unlock() + + // A JIT request for this site is already in-flight - avoid duplicate sends. + if _, pending := o.jitPendingSites[siteId]; pending { + return + } + + chainId := generateChainId() + logger.Info("DNS-triggered JIT connect for site %d (chainId=%s)", siteId, chainId) + stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{ + "siteId": siteId, + "chainId": chainId, + }, 2*time.Second, 10) + o.stopPeerInits[chainId] = stopFunc + o.jitPendingSites[siteId] = chainId + }) + if o.tunnelConfig.OverrideDNS { // Set up DNS override to use our DNS proxy if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil { diff --git a/olm/olm.go b/olm/olm.go index b2843d2..8b01f9d 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -67,8 +67,9 @@ type Olm struct { stopRegister func() updateRegister func(newData any) - stopPeerSends map[string]func() - stopPeerInits map[string]func() + stopPeerSends map[string]func() + stopPeerInits map[string]func() + jitPendingSites map[int]string // siteId -> chainId for in-flight JIT requests peerSendMu sync.Mutex // WaitGroup to track tunnel lifecycle @@ -181,8 +182,9 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) { olmCtx: ctx, apiServer: apiServer, olmConfig: config, - stopPeerSends: make(map[string]func()), - stopPeerInits: make(map[string]func()), + stopPeerSends: make(map[string]func()), + stopPeerInits: make(map[string]func()), + jitPendingSites: make(map[int]string), } newOlm.registerAPICallbacks() @@ -560,6 +562,7 @@ func (o *Olm) Close() { } } o.stopPeerSends = make(map[string]func()) + o.jitPendingSites = make(map[int]string) o.peerSendMu.Unlock() // send a disconnect message to the cloud to show disconnected diff --git a/olm/peer.go b/olm/peer.go index 9f02bb2..da5a884 100644 --- a/olm/peer.go +++ b/olm/peer.go @@ -272,6 +272,9 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { stop() delete(o.stopPeerInits, handshakeData.ChainId) } + // If this chain was initiated by a DNS-triggered JIT request, clear the + // pending entry so the site can be re-triggered if needed in the future. + delete(o.jitPendingSites, handshakeData.SiteId) o.peerSendMu.Unlock() } @@ -353,6 +356,14 @@ func (o *Olm) handleCancelChain(msg websocket.WSMessage) { delete(o.stopPeerInits, cancelData.ChainId) found = true } + // If this chain was a DNS-triggered JIT request, clear the pending entry so + // the site can be re-triggered on the next DNS lookup. + for siteId, chainId := range o.jitPendingSites { + if chainId == cancelData.ChainId { + delete(o.jitPendingSites, siteId) + break + } + } if stop, ok := o.stopPeerSends[cancelData.ChainId]; ok { stop() From c2b5ef96a464af23745f7e084e0a4e831f072d57 Mon Sep 17 00:00:00 2001 From: Owen Date: Thu, 12 Mar 2026 17:26:46 -0700 Subject: [PATCH 12/15] Jit of aliases working --- dns/dns_proxy.go | 1 + dns/dns_records.go | 11 +++++++++++ olm/connect.go | 28 +++++++++++++--------------- peers/manager.go | 22 ++++++++++++++-------- 4 files changed, 39 insertions(+), 23 deletions(-) diff --git a/dns/dns_proxy.go b/dns/dns_proxy.go index 7a69f53..9451ba8 100644 --- a/dns/dns_proxy.go +++ b/dns/dns_proxy.go @@ -745,6 +745,7 @@ func (p *DNSProxy) SetJITHandler(handler func(siteId int)) { // domain should be a domain name (e.g., "example.com" or "example.com.") // ip should be a valid IPv4 or IPv6 address func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP, siteId int) error { + logger.Debug("Adding dns record for domain %s with IP %s (siteId=%d)", domain, ip.String(), siteId) return p.recordStore.AddRecord(domain, ip, siteId) } diff --git a/dns/dns_records.go b/dns/dns_records.go index c52c08e..270bae6 100644 --- a/dns/dns_records.go +++ b/dns/dns_records.go @@ -75,8 +75,18 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP, siteId int) error { } rs := m[domain] if isV4 { + for _, existing := range rs.A { + if existing.Equal(ip) { + return nil + } + } rs.A = append(rs.A, ip) } else { + for _, existing := range rs.AAAA { + if existing.Equal(ip) { + return nil + } + } rs.AAAA = append(rs.AAAA, ip) } @@ -87,6 +97,7 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP, siteId int) error { return nil } + // AddPTRRecord adds a PTR record mapping an IP address to a domain name // ip should be a valid IPv4 or IPv6 address // domain should be in FQDN format (e.g., "example.com.") diff --git a/olm/connect.go b/olm/connect.go index f5a0ccd..3a2000c 100644 --- a/olm/connect.go +++ b/olm/connect.go @@ -175,21 +175,19 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) { for i := range wgData.Sites { site := wgData.Sites[i] - if site.PublicKey == "" { - logger.Warn("Skipping site %d (%s): no public key available (site may not be connected)", site.SiteId, site.Name) - continue + if site.PublicKey != "" { + var siteEndpoint string + // here we are going to take the relay endpoint if it exists which means we requested a relay for this peer + if site.RelayEndpoint != "" { + siteEndpoint = site.RelayEndpoint + } else { + siteEndpoint = site.Endpoint + } + + o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false) } - var siteEndpoint string - // here we are going to take the relay endpoint if it exists which means we requested a relay for this peer - if site.RelayEndpoint != "" { - siteEndpoint = site.RelayEndpoint - } else { - siteEndpoint = site.Endpoint - } - - o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false) - + // we still call this to add the aliases for jit lookup but we just do that then pass inside. need to skip the above so we dont add to the api if err := o.peerManager.AddPeer(site); err != nil { logger.Error("Failed to add peer: %v", err) return @@ -311,12 +309,12 @@ func (o *Olm) handleTerminate(msg websocket.WSMessage) { logger.Error("Error unmarshaling terminate error data: %v", err) } else { logger.Info("Terminate reason (code: %s): %s", errorData.Code, errorData.Message) - + if errorData.Code == "TERMINATED_INACTIVITY" { logger.Info("Ignoring...") return } - + // Set the olm error in the API server so it can be exposed via status o.apiServer.SetOlmError(errorData.Code, errorData.Message) } diff --git a/peers/manager.go b/peers/manager.go index c5bb291..9cc1e75 100644 --- a/peers/manager.go +++ b/peers/manager.go @@ -110,6 +110,19 @@ func (pm *PeerManager) GetAllPeers() []SiteConfig { func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error { pm.mu.Lock() defer pm.mu.Unlock() + + for _, alias := range siteConfig.Aliases { + address := net.ParseIP(alias.AliasAddress) + if address == nil { + continue + } + pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteConfig.SiteId) + } + + if siteConfig.PublicKey == "" { + logger.Debug("Skip adding site %d because no pub key", siteConfig.SiteId) + return nil + } // build the allowed IPs list from the remote subnets and aliases and add them to the peer allowedIPs := make([]string, 0, len(siteConfig.RemoteSubnets)+len(siteConfig.Aliases)) @@ -143,14 +156,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error { if err := network.AddRoutes(siteConfig.RemoteSubnets, pm.interfaceName); err != nil { logger.Error("Failed to add routes for remote subnets: %v", err) } - for _, alias := range siteConfig.Aliases { - address := net.ParseIP(alias.AliasAddress) - if address == nil { - continue - } - pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteConfig.SiteId) - } - + monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0] monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port From 3de8dc9fc292db886d29c011bf6bba6ee458526a Mon Sep 17 00:00:00 2001 From: Owen Date: Thu, 12 Mar 2026 17:49:12 -0700 Subject: [PATCH 13/15] Add optional compression --- websocket/client.go | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/websocket/client.go b/websocket/client.go index dcf6acd..35b76b9 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -2,6 +2,7 @@ package websocket import ( "bytes" + "compress/gzip" "crypto/tls" "crypto/x509" "encoding/json" @@ -803,8 +804,7 @@ func (c *Client) readPumpWithDisconnectDetection() { case <-c.done: return default: - var msg WSMessage - err := c.conn.ReadJSON(&msg) + messageType, p, err := c.conn.ReadMessage() if err != nil { // Check if we're shutting down or explicitly disconnected before logging error select { @@ -829,6 +829,30 @@ func (c *Client) readPumpWithDisconnectDetection() { } } + // Decompress binary frames (gzip-compressed JSON) + var data []byte + if messageType == websocket.BinaryMessage { + gr, gzErr := gzip.NewReader(bytes.NewReader(p)) + if gzErr != nil { + logger.Error("websocket: failed to create gzip reader: %v", gzErr) + continue + } + data, gzErr = io.ReadAll(gr) + gr.Close() + if gzErr != nil { + logger.Error("websocket: failed to decompress message: %v", gzErr) + continue + } + } else { + data = p + } + + var msg WSMessage + if err = json.Unmarshal(data, &msg); err != nil { + logger.Error("websocket: failed to parse message: %v", err) + continue + } + // Update config version from incoming message c.setConfigVersion(msg.ConfigVersion) From 4bc0508c7dd7d5571f00883f69244be8abb1136c Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 16 Mar 2026 13:50:21 -0700 Subject: [PATCH 14/15] Remove redundant info --- olm/olm.go | 1 - websocket/client.go | 4 +--- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/olm/olm.go b/olm/olm.go index 96b0817..a458f8a 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -375,7 +375,6 @@ func (o *Olm) StartTunnel(config TunnelConfig) { config.OrgID, config.Endpoint, 30*time.Second, // 30 seconds - config.PingTimeoutDuration, websocket.WithPingDataProvider(func() map[string]any { o.metaMu.Lock() defer o.metaMu.Unlock() diff --git a/websocket/client.go b/websocket/client.go index 35b76b9..3b4e894 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -83,7 +83,6 @@ type Client struct { isDisconnected bool // Flag to track if client is intentionally disconnected reconnectMux sync.RWMutex pingInterval time.Duration - pingTimeout time.Duration onConnect func() error onTokenUpdate func(token string, exitNodes []ExitNode) onAuthError func(statusCode int, message string) // Callback for auth errors @@ -159,7 +158,7 @@ func (c *Client) OnAuthError(callback func(statusCode int, message string)) { } // NewClient creates a new websocket client -func NewClient(ID, secret, userToken, orgId, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) { +func NewClient(ID, secret, userToken, orgId, endpoint string, pingInterval time.Duration, opts ...ClientOption) (*Client, error) { config := &Config{ ID: ID, Secret: secret, @@ -176,7 +175,6 @@ func NewClient(ID, secret, userToken, orgId, endpoint string, pingInterval time. reconnectInterval: 3 * time.Second, isConnected: false, pingInterval: pingInterval, - pingTimeout: pingTimeout, clientType: "olm", pingDone: make(chan struct{}), } From 703c606af566c34445ccb8f7b995183918e1b6aa Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 16 Mar 2026 14:31:16 -0700 Subject: [PATCH 15/15] Handle no chainId case --- main.go | 2 +- olm/peer.go | 22 +++++++++++++++++++--- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/main.go b/main.go index 2bf8dcd..6ea7d11 100644 --- a/main.go +++ b/main.go @@ -190,7 +190,7 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt os.Exit(0) } - olmVersion := "version_replaceme" + olmVersion := "1.4.3" if showVersion { fmt.Println("Olm version " + olmVersion) os.Exit(0) diff --git a/olm/peer.go b/olm/peer.go index 4a9a54c..fca47b5 100644 --- a/olm/peer.go +++ b/olm/peer.go @@ -42,8 +42,16 @@ func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) { delete(o.stopPeerSends, siteConfigMsg.ChainId) } o.peerSendMu.Unlock() + } else { + // stop all of the stopPeerSends + o.peerSendMu.Lock() + for _, stop := range o.stopPeerSends { + stop() + } + o.stopPeerSends = make(map[string]func()) + o.peerSendMu.Unlock() } - + if siteConfigMsg.PublicKey == "" { logger.Warn("Skipping add-peer for site %d (%s): no public key available (site may not be connected)", siteConfigMsg.SiteId, siteConfigMsg.Name) return @@ -190,7 +198,7 @@ func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) { } primaryRelay, err := util.ResolveDomainUpstream(relayData.RelayEndpoint, o.tunnelConfig.PublicDNS) - + if err != nil { logger.Error("Failed to resolve primary relay endpoint: %v", err) return @@ -231,7 +239,7 @@ func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) { } primaryRelay, err := util.ResolveDomainUpstream(relayData.Endpoint, o.tunnelConfig.PublicDNS) - + if err != nil { logger.Warn("Failed to resolve primary relay endpoint: %v", err) } @@ -283,6 +291,14 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { // pending entry so the site can be re-triggered if needed in the future. delete(o.jitPendingSites, handshakeData.SiteId) o.peerSendMu.Unlock() + } else { + // Stop all of the stopPeerInits + o.peerSendMu.Lock() + for _, stop := range o.stopPeerInits { + stop() + } + o.stopPeerInits = make(map[string]func()) + o.peerSendMu.Unlock() } // Get existing peer from PeerManager