mirror of
https://github.com/fosrl/olm.git
synced 2026-03-13 14:16:41 +00:00
Alias jit handler
This commit is contained in:
@@ -45,6 +45,11 @@ type DNSProxy struct {
|
||||
tunnelActivePorts map[uint16]bool
|
||||
tunnelPortsLock sync.Mutex
|
||||
|
||||
// jitHandler is called when a local record is resolved for a site that may not be
|
||||
// connected yet, giving the caller a chance to initiate a JIT connection.
|
||||
// It is invoked asynchronously so it never blocks DNS resolution.
|
||||
jitHandler func(siteId int)
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
@@ -384,6 +389,16 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie
|
||||
response = p.checkLocalRecords(msg, question)
|
||||
}
|
||||
|
||||
// If a local A/AAAA record was found, notify the JIT handler so that the owning
|
||||
// site can be connected on-demand if it is not yet active.
|
||||
if response != nil && p.jitHandler != nil &&
|
||||
(question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA) {
|
||||
if siteId, ok := p.recordStore.GetSiteIdForDomain(question.Name); ok && siteId != 0 {
|
||||
handler := p.jitHandler
|
||||
go handler(siteId)
|
||||
}
|
||||
}
|
||||
|
||||
// If no local records, forward to upstream
|
||||
if response == nil {
|
||||
logger.Debug("No local record for %s, forwarding upstream to %v", question.Name, p.upstreamDNS)
|
||||
@@ -718,6 +733,14 @@ func (p *DNSProxy) runPacketSender() {
|
||||
}
|
||||
}
|
||||
|
||||
// SetJITHandler registers a callback that is invoked whenever a local DNS record is
|
||||
// resolved for an A or AAAA query. The siteId identifies which site owns the record.
|
||||
// The handler is called in its own goroutine so it must be safe to call concurrently.
|
||||
// Pass nil to disable JIT notifications.
|
||||
func (p *DNSProxy) SetJITHandler(handler func(siteId int)) {
|
||||
p.jitHandler = handler
|
||||
}
|
||||
|
||||
// AddDNSRecord adds a DNS record to the local store
|
||||
// domain should be a domain name (e.g., "example.com" or "example.com.")
|
||||
// ip should be a valid IPv4 or IPv6 address
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/network"
|
||||
@@ -196,6 +197,36 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
|
||||
logger.Error("Failed to start DNS proxy: %v", err)
|
||||
}
|
||||
|
||||
// Register JIT handler: when the DNS proxy resolves a local record, check whether
|
||||
// the owning site is already connected and, if not, initiate a JIT connection.
|
||||
o.dnsProxy.SetJITHandler(func(siteId int) {
|
||||
if o.peerManager == nil || o.websocket == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Site already has an active peer connection - nothing to do.
|
||||
if _, exists := o.peerManager.GetPeer(siteId); exists {
|
||||
return
|
||||
}
|
||||
|
||||
o.peerSendMu.Lock()
|
||||
defer o.peerSendMu.Unlock()
|
||||
|
||||
// A JIT request for this site is already in-flight - avoid duplicate sends.
|
||||
if _, pending := o.jitPendingSites[siteId]; pending {
|
||||
return
|
||||
}
|
||||
|
||||
chainId := generateChainId()
|
||||
logger.Info("DNS-triggered JIT connect for site %d (chainId=%s)", siteId, chainId)
|
||||
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{
|
||||
"siteId": siteId,
|
||||
"chainId": chainId,
|
||||
}, 2*time.Second, 10)
|
||||
o.stopPeerInits[chainId] = stopFunc
|
||||
o.jitPendingSites[siteId] = chainId
|
||||
})
|
||||
|
||||
if o.tunnelConfig.OverrideDNS {
|
||||
// Set up DNS override to use our DNS proxy
|
||||
if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil {
|
||||
|
||||
11
olm/olm.go
11
olm/olm.go
@@ -67,8 +67,9 @@ type Olm struct {
|
||||
stopRegister func()
|
||||
updateRegister func(newData any)
|
||||
|
||||
stopPeerSends map[string]func()
|
||||
stopPeerInits map[string]func()
|
||||
stopPeerSends map[string]func()
|
||||
stopPeerInits map[string]func()
|
||||
jitPendingSites map[int]string // siteId -> chainId for in-flight JIT requests
|
||||
peerSendMu sync.Mutex
|
||||
|
||||
// WaitGroup to track tunnel lifecycle
|
||||
@@ -181,8 +182,9 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
|
||||
olmCtx: ctx,
|
||||
apiServer: apiServer,
|
||||
olmConfig: config,
|
||||
stopPeerSends: make(map[string]func()),
|
||||
stopPeerInits: make(map[string]func()),
|
||||
stopPeerSends: make(map[string]func()),
|
||||
stopPeerInits: make(map[string]func()),
|
||||
jitPendingSites: make(map[int]string),
|
||||
}
|
||||
|
||||
newOlm.registerAPICallbacks()
|
||||
@@ -560,6 +562,7 @@ func (o *Olm) Close() {
|
||||
}
|
||||
}
|
||||
o.stopPeerSends = make(map[string]func())
|
||||
o.jitPendingSites = make(map[int]string)
|
||||
o.peerSendMu.Unlock()
|
||||
|
||||
// send a disconnect message to the cloud to show disconnected
|
||||
|
||||
11
olm/peer.go
11
olm/peer.go
@@ -272,6 +272,9 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
|
||||
stop()
|
||||
delete(o.stopPeerInits, handshakeData.ChainId)
|
||||
}
|
||||
// If this chain was initiated by a DNS-triggered JIT request, clear the
|
||||
// pending entry so the site can be re-triggered if needed in the future.
|
||||
delete(o.jitPendingSites, handshakeData.SiteId)
|
||||
o.peerSendMu.Unlock()
|
||||
}
|
||||
|
||||
@@ -353,6 +356,14 @@ func (o *Olm) handleCancelChain(msg websocket.WSMessage) {
|
||||
delete(o.stopPeerInits, cancelData.ChainId)
|
||||
found = true
|
||||
}
|
||||
// If this chain was a DNS-triggered JIT request, clear the pending entry so
|
||||
// the site can be re-triggered on the next DNS lookup.
|
||||
for siteId, chainId := range o.jitPendingSites {
|
||||
if chainId == cancelData.ChainId {
|
||||
delete(o.jitPendingSites, siteId)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if stop, ok := o.stopPeerSends[cancelData.ChainId]; ok {
|
||||
stop()
|
||||
|
||||
Reference in New Issue
Block a user