diff --git a/dns/dns_proxy.go b/dns/dns_proxy.go index 7b7858c..7a69f53 100644 --- a/dns/dns_proxy.go +++ b/dns/dns_proxy.go @@ -45,6 +45,11 @@ type DNSProxy struct { tunnelActivePorts map[uint16]bool tunnelPortsLock sync.Mutex + // jitHandler is called when a local record is resolved for a site that may not be + // connected yet, giving the caller a chance to initiate a JIT connection. + // It is invoked asynchronously so it never blocks DNS resolution. + jitHandler func(siteId int) + ctx context.Context cancel context.CancelFunc wg sync.WaitGroup @@ -384,6 +389,16 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie response = p.checkLocalRecords(msg, question) } + // If a local A/AAAA record was found, notify the JIT handler so that the owning + // site can be connected on-demand if it is not yet active. + if response != nil && p.jitHandler != nil && + (question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA) { + if siteId, ok := p.recordStore.GetSiteIdForDomain(question.Name); ok && siteId != 0 { + handler := p.jitHandler + go handler(siteId) + } + } + // If no local records, forward to upstream if response == nil { logger.Debug("No local record for %s, forwarding upstream to %v", question.Name, p.upstreamDNS) @@ -718,6 +733,14 @@ func (p *DNSProxy) runPacketSender() { } } +// SetJITHandler registers a callback that is invoked whenever a local DNS record is +// resolved for an A or AAAA query. The siteId identifies which site owns the record. +// The handler is called in its own goroutine so it must be safe to call concurrently. +// Pass nil to disable JIT notifications. +func (p *DNSProxy) SetJITHandler(handler func(siteId int)) { + p.jitHandler = handler +} + // AddDNSRecord adds a DNS record to the local store // domain should be a domain name (e.g., "example.com" or "example.com.") // ip should be a valid IPv4 or IPv6 address diff --git a/olm/connect.go b/olm/connect.go index dc05d1f..1e00ee2 100644 --- a/olm/connect.go +++ b/olm/connect.go @@ -7,6 +7,7 @@ import ( "runtime" "strconv" "strings" + "time" "github.com/fosrl/newt/logger" "github.com/fosrl/newt/network" @@ -196,6 +197,36 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) { logger.Error("Failed to start DNS proxy: %v", err) } + // Register JIT handler: when the DNS proxy resolves a local record, check whether + // the owning site is already connected and, if not, initiate a JIT connection. + o.dnsProxy.SetJITHandler(func(siteId int) { + if o.peerManager == nil || o.websocket == nil { + return + } + + // Site already has an active peer connection - nothing to do. + if _, exists := o.peerManager.GetPeer(siteId); exists { + return + } + + o.peerSendMu.Lock() + defer o.peerSendMu.Unlock() + + // A JIT request for this site is already in-flight - avoid duplicate sends. + if _, pending := o.jitPendingSites[siteId]; pending { + return + } + + chainId := generateChainId() + logger.Info("DNS-triggered JIT connect for site %d (chainId=%s)", siteId, chainId) + stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{ + "siteId": siteId, + "chainId": chainId, + }, 2*time.Second, 10) + o.stopPeerInits[chainId] = stopFunc + o.jitPendingSites[siteId] = chainId + }) + if o.tunnelConfig.OverrideDNS { // Set up DNS override to use our DNS proxy if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil { diff --git a/olm/olm.go b/olm/olm.go index b2843d2..8b01f9d 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -67,8 +67,9 @@ type Olm struct { stopRegister func() updateRegister func(newData any) - stopPeerSends map[string]func() - stopPeerInits map[string]func() + stopPeerSends map[string]func() + stopPeerInits map[string]func() + jitPendingSites map[int]string // siteId -> chainId for in-flight JIT requests peerSendMu sync.Mutex // WaitGroup to track tunnel lifecycle @@ -181,8 +182,9 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) { olmCtx: ctx, apiServer: apiServer, olmConfig: config, - stopPeerSends: make(map[string]func()), - stopPeerInits: make(map[string]func()), + stopPeerSends: make(map[string]func()), + stopPeerInits: make(map[string]func()), + jitPendingSites: make(map[int]string), } newOlm.registerAPICallbacks() @@ -560,6 +562,7 @@ func (o *Olm) Close() { } } o.stopPeerSends = make(map[string]func()) + o.jitPendingSites = make(map[int]string) o.peerSendMu.Unlock() // send a disconnect message to the cloud to show disconnected diff --git a/olm/peer.go b/olm/peer.go index 9f02bb2..da5a884 100644 --- a/olm/peer.go +++ b/olm/peer.go @@ -272,6 +272,9 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { stop() delete(o.stopPeerInits, handshakeData.ChainId) } + // If this chain was initiated by a DNS-triggered JIT request, clear the + // pending entry so the site can be re-triggered if needed in the future. + delete(o.jitPendingSites, handshakeData.SiteId) o.peerSendMu.Unlock() } @@ -353,6 +356,14 @@ func (o *Olm) handleCancelChain(msg websocket.WSMessage) { delete(o.stopPeerInits, cancelData.ChainId) found = true } + // If this chain was a DNS-triggered JIT request, clear the pending entry so + // the site can be re-triggered on the next DNS lookup. + for siteId, chainId := range o.jitPendingSites { + if chainId == cancelData.ChainId { + delete(o.jitPendingSites, siteId) + break + } + } if stop, ok := o.stopPeerSends[cancelData.ChainId]; ok { stop()