Alias jit handler

This commit is contained in:
Owen
2026-03-11 15:56:51 -07:00
parent e2690bcc03
commit 22cd02ae15
4 changed files with 72 additions and 4 deletions

View File

@@ -45,6 +45,11 @@ type DNSProxy struct {
tunnelActivePorts map[uint16]bool tunnelActivePorts map[uint16]bool
tunnelPortsLock sync.Mutex tunnelPortsLock sync.Mutex
// jitHandler is called when a local record is resolved for a site that may not be
// connected yet, giving the caller a chance to initiate a JIT connection.
// It is invoked asynchronously so it never blocks DNS resolution.
jitHandler func(siteId int)
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
wg sync.WaitGroup wg sync.WaitGroup
@@ -384,6 +389,16 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie
response = p.checkLocalRecords(msg, question) response = p.checkLocalRecords(msg, question)
} }
// If a local A/AAAA record was found, notify the JIT handler so that the owning
// site can be connected on-demand if it is not yet active.
if response != nil && p.jitHandler != nil &&
(question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA) {
if siteId, ok := p.recordStore.GetSiteIdForDomain(question.Name); ok && siteId != 0 {
handler := p.jitHandler
go handler(siteId)
}
}
// If no local records, forward to upstream // If no local records, forward to upstream
if response == nil { if response == nil {
logger.Debug("No local record for %s, forwarding upstream to %v", question.Name, p.upstreamDNS) logger.Debug("No local record for %s, forwarding upstream to %v", question.Name, p.upstreamDNS)
@@ -718,6 +733,14 @@ func (p *DNSProxy) runPacketSender() {
} }
} }
// SetJITHandler registers a callback that is invoked whenever a local DNS record is
// resolved for an A or AAAA query. The siteId identifies which site owns the record.
// The handler is called in its own goroutine so it must be safe to call concurrently.
// Pass nil to disable JIT notifications.
func (p *DNSProxy) SetJITHandler(handler func(siteId int)) {
p.jitHandler = handler
}
// AddDNSRecord adds a DNS record to the local store // AddDNSRecord adds a DNS record to the local store
// domain should be a domain name (e.g., "example.com" or "example.com.") // domain should be a domain name (e.g., "example.com" or "example.com.")
// ip should be a valid IPv4 or IPv6 address // ip should be a valid IPv4 or IPv6 address

View File

@@ -7,6 +7,7 @@ import (
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"github.com/fosrl/newt/network" "github.com/fosrl/newt/network"
@@ -196,6 +197,36 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
logger.Error("Failed to start DNS proxy: %v", err) logger.Error("Failed to start DNS proxy: %v", err)
} }
// Register JIT handler: when the DNS proxy resolves a local record, check whether
// the owning site is already connected and, if not, initiate a JIT connection.
o.dnsProxy.SetJITHandler(func(siteId int) {
if o.peerManager == nil || o.websocket == nil {
return
}
// Site already has an active peer connection - nothing to do.
if _, exists := o.peerManager.GetPeer(siteId); exists {
return
}
o.peerSendMu.Lock()
defer o.peerSendMu.Unlock()
// A JIT request for this site is already in-flight - avoid duplicate sends.
if _, pending := o.jitPendingSites[siteId]; pending {
return
}
chainId := generateChainId()
logger.Info("DNS-triggered JIT connect for site %d (chainId=%s)", siteId, chainId)
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{
"siteId": siteId,
"chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerInits[chainId] = stopFunc
o.jitPendingSites[siteId] = chainId
})
if o.tunnelConfig.OverrideDNS { if o.tunnelConfig.OverrideDNS {
// Set up DNS override to use our DNS proxy // Set up DNS override to use our DNS proxy
if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil { if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil {

View File

@@ -69,6 +69,7 @@ type Olm struct {
stopPeerSends map[string]func() stopPeerSends map[string]func()
stopPeerInits map[string]func() stopPeerInits map[string]func()
jitPendingSites map[int]string // siteId -> chainId for in-flight JIT requests
peerSendMu sync.Mutex peerSendMu sync.Mutex
// WaitGroup to track tunnel lifecycle // WaitGroup to track tunnel lifecycle
@@ -183,6 +184,7 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
olmConfig: config, olmConfig: config,
stopPeerSends: make(map[string]func()), stopPeerSends: make(map[string]func()),
stopPeerInits: make(map[string]func()), stopPeerInits: make(map[string]func()),
jitPendingSites: make(map[int]string),
} }
newOlm.registerAPICallbacks() newOlm.registerAPICallbacks()
@@ -560,6 +562,7 @@ func (o *Olm) Close() {
} }
} }
o.stopPeerSends = make(map[string]func()) o.stopPeerSends = make(map[string]func())
o.jitPendingSites = make(map[int]string)
o.peerSendMu.Unlock() o.peerSendMu.Unlock()
// send a disconnect message to the cloud to show disconnected // send a disconnect message to the cloud to show disconnected

View File

@@ -272,6 +272,9 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
stop() stop()
delete(o.stopPeerInits, handshakeData.ChainId) delete(o.stopPeerInits, handshakeData.ChainId)
} }
// If this chain was initiated by a DNS-triggered JIT request, clear the
// pending entry so the site can be re-triggered if needed in the future.
delete(o.jitPendingSites, handshakeData.SiteId)
o.peerSendMu.Unlock() o.peerSendMu.Unlock()
} }
@@ -353,6 +356,14 @@ func (o *Olm) handleCancelChain(msg websocket.WSMessage) {
delete(o.stopPeerInits, cancelData.ChainId) delete(o.stopPeerInits, cancelData.ChainId)
found = true found = true
} }
// If this chain was a DNS-triggered JIT request, clear the pending entry so
// the site can be re-triggered on the next DNS lookup.
for siteId, chainId := range o.jitPendingSites {
if chainId == cancelData.ChainId {
delete(o.jitPendingSites, siteId)
break
}
}
if stop, ok := o.stopPeerSends[cancelData.ChainId]; ok { if stop, ok := o.stopPeerSends[cancelData.ChainId]; ok {
stop() stop()