mirror of
https://github.com/fosrl/olm.git
synced 2026-03-13 14:16:41 +00:00
Alias jit handler
This commit is contained in:
@@ -45,6 +45,11 @@ type DNSProxy struct {
|
|||||||
tunnelActivePorts map[uint16]bool
|
tunnelActivePorts map[uint16]bool
|
||||||
tunnelPortsLock sync.Mutex
|
tunnelPortsLock sync.Mutex
|
||||||
|
|
||||||
|
// jitHandler is called when a local record is resolved for a site that may not be
|
||||||
|
// connected yet, giving the caller a chance to initiate a JIT connection.
|
||||||
|
// It is invoked asynchronously so it never blocks DNS resolution.
|
||||||
|
jitHandler func(siteId int)
|
||||||
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
@@ -384,6 +389,16 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie
|
|||||||
response = p.checkLocalRecords(msg, question)
|
response = p.checkLocalRecords(msg, question)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If a local A/AAAA record was found, notify the JIT handler so that the owning
|
||||||
|
// site can be connected on-demand if it is not yet active.
|
||||||
|
if response != nil && p.jitHandler != nil &&
|
||||||
|
(question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA) {
|
||||||
|
if siteId, ok := p.recordStore.GetSiteIdForDomain(question.Name); ok && siteId != 0 {
|
||||||
|
handler := p.jitHandler
|
||||||
|
go handler(siteId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// If no local records, forward to upstream
|
// If no local records, forward to upstream
|
||||||
if response == nil {
|
if response == nil {
|
||||||
logger.Debug("No local record for %s, forwarding upstream to %v", question.Name, p.upstreamDNS)
|
logger.Debug("No local record for %s, forwarding upstream to %v", question.Name, p.upstreamDNS)
|
||||||
@@ -718,6 +733,14 @@ func (p *DNSProxy) runPacketSender() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetJITHandler registers a callback that is invoked whenever a local DNS record is
|
||||||
|
// resolved for an A or AAAA query. The siteId identifies which site owns the record.
|
||||||
|
// The handler is called in its own goroutine so it must be safe to call concurrently.
|
||||||
|
// Pass nil to disable JIT notifications.
|
||||||
|
func (p *DNSProxy) SetJITHandler(handler func(siteId int)) {
|
||||||
|
p.jitHandler = handler
|
||||||
|
}
|
||||||
|
|
||||||
// AddDNSRecord adds a DNS record to the local store
|
// AddDNSRecord adds a DNS record to the local store
|
||||||
// domain should be a domain name (e.g., "example.com" or "example.com.")
|
// domain should be a domain name (e.g., "example.com" or "example.com.")
|
||||||
// ip should be a valid IPv4 or IPv6 address
|
// ip should be a valid IPv4 or IPv6 address
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/newt/network"
|
"github.com/fosrl/newt/network"
|
||||||
@@ -196,6 +197,36 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
|
|||||||
logger.Error("Failed to start DNS proxy: %v", err)
|
logger.Error("Failed to start DNS proxy: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Register JIT handler: when the DNS proxy resolves a local record, check whether
|
||||||
|
// the owning site is already connected and, if not, initiate a JIT connection.
|
||||||
|
o.dnsProxy.SetJITHandler(func(siteId int) {
|
||||||
|
if o.peerManager == nil || o.websocket == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Site already has an active peer connection - nothing to do.
|
||||||
|
if _, exists := o.peerManager.GetPeer(siteId); exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
o.peerSendMu.Lock()
|
||||||
|
defer o.peerSendMu.Unlock()
|
||||||
|
|
||||||
|
// A JIT request for this site is already in-flight - avoid duplicate sends.
|
||||||
|
if _, pending := o.jitPendingSites[siteId]; pending {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
chainId := generateChainId()
|
||||||
|
logger.Info("DNS-triggered JIT connect for site %d (chainId=%s)", siteId, chainId)
|
||||||
|
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{
|
||||||
|
"siteId": siteId,
|
||||||
|
"chainId": chainId,
|
||||||
|
}, 2*time.Second, 10)
|
||||||
|
o.stopPeerInits[chainId] = stopFunc
|
||||||
|
o.jitPendingSites[siteId] = chainId
|
||||||
|
})
|
||||||
|
|
||||||
if o.tunnelConfig.OverrideDNS {
|
if o.tunnelConfig.OverrideDNS {
|
||||||
// Set up DNS override to use our DNS proxy
|
// Set up DNS override to use our DNS proxy
|
||||||
if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil {
|
if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil {
|
||||||
|
|||||||
@@ -69,6 +69,7 @@ type Olm struct {
|
|||||||
|
|
||||||
stopPeerSends map[string]func()
|
stopPeerSends map[string]func()
|
||||||
stopPeerInits map[string]func()
|
stopPeerInits map[string]func()
|
||||||
|
jitPendingSites map[int]string // siteId -> chainId for in-flight JIT requests
|
||||||
peerSendMu sync.Mutex
|
peerSendMu sync.Mutex
|
||||||
|
|
||||||
// WaitGroup to track tunnel lifecycle
|
// WaitGroup to track tunnel lifecycle
|
||||||
@@ -183,6 +184,7 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
|
|||||||
olmConfig: config,
|
olmConfig: config,
|
||||||
stopPeerSends: make(map[string]func()),
|
stopPeerSends: make(map[string]func()),
|
||||||
stopPeerInits: make(map[string]func()),
|
stopPeerInits: make(map[string]func()),
|
||||||
|
jitPendingSites: make(map[int]string),
|
||||||
}
|
}
|
||||||
|
|
||||||
newOlm.registerAPICallbacks()
|
newOlm.registerAPICallbacks()
|
||||||
@@ -560,6 +562,7 @@ func (o *Olm) Close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
o.stopPeerSends = make(map[string]func())
|
o.stopPeerSends = make(map[string]func())
|
||||||
|
o.jitPendingSites = make(map[int]string)
|
||||||
o.peerSendMu.Unlock()
|
o.peerSendMu.Unlock()
|
||||||
|
|
||||||
// send a disconnect message to the cloud to show disconnected
|
// send a disconnect message to the cloud to show disconnected
|
||||||
|
|||||||
11
olm/peer.go
11
olm/peer.go
@@ -272,6 +272,9 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
|
|||||||
stop()
|
stop()
|
||||||
delete(o.stopPeerInits, handshakeData.ChainId)
|
delete(o.stopPeerInits, handshakeData.ChainId)
|
||||||
}
|
}
|
||||||
|
// If this chain was initiated by a DNS-triggered JIT request, clear the
|
||||||
|
// pending entry so the site can be re-triggered if needed in the future.
|
||||||
|
delete(o.jitPendingSites, handshakeData.SiteId)
|
||||||
o.peerSendMu.Unlock()
|
o.peerSendMu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -353,6 +356,14 @@ func (o *Olm) handleCancelChain(msg websocket.WSMessage) {
|
|||||||
delete(o.stopPeerInits, cancelData.ChainId)
|
delete(o.stopPeerInits, cancelData.ChainId)
|
||||||
found = true
|
found = true
|
||||||
}
|
}
|
||||||
|
// If this chain was a DNS-triggered JIT request, clear the pending entry so
|
||||||
|
// the site can be re-triggered on the next DNS lookup.
|
||||||
|
for siteId, chainId := range o.jitPendingSites {
|
||||||
|
if chainId == cancelData.ChainId {
|
||||||
|
delete(o.jitPendingSites, siteId)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if stop, ok := o.stopPeerSends[cancelData.ChainId]; ok {
|
if stop, ok := o.stopPeerSends[cancelData.ChainId]; ok {
|
||||||
stop()
|
stop()
|
||||||
|
|||||||
Reference in New Issue
Block a user