Compare commits

..

18 Commits
v1.4.3 ... jit

Author SHA1 Message Date
Owen
c2b5ef96a4 Jit of aliases working 2026-03-12 17:26:46 -07:00
Owen
e326da3d3e Merge branch 'dev' into jit 2026-03-12 16:53:16 -07:00
Owen
53def4e2f6 Merge branch 'main' into dev 2026-03-12 16:51:06 -07:00
Owen
e85fd9d71e Bump newt version 2026-03-12 16:50:41 -07:00
Owen
98a24960f5 Remove extra restore function 2026-03-12 16:50:41 -07:00
Owen
e82387d515 Actually pull the upstream from the dns var 2026-03-12 16:50:41 -07:00
Owen
b3cb3e1c92 Add hardcoded public dns 2026-03-12 16:50:41 -07:00
Owen
22cd02ae15 Alias jit handler 2026-03-11 15:56:51 -07:00
André Gilerson
3f258d3500 Fix crash when peer has nil publicKey in site config
Skip sites with empty/nil publicKey instead of passing them to the
WireGuard UAPI layer, which expects a valid 64-char hex string. A nil
key occurs when a Newt site has never connected. Previously this caused
all sites to fail with "hex string does not fit the slice".
2026-03-07 20:44:25 -08:00
Owen
e2690bcc03 Store site id 2026-03-06 16:19:00 -08:00
Owen
f2d0e6a14c Merge branch 'dev' into jit 2026-03-06 16:08:24 -08:00
Laurence
ae88766d85 test(dns): add dns test cases for nodata 2026-03-06 16:08:01 -08:00
Laurence
9ae49e36d5 refactor(dns): simplify DNSRecordStore from trie to map
Replace trie-based domain lookup with simple map for O(1) lookups.
  Add exists boolean to GetRecords for proper NODATA vs NXDOMAIN responses.
2026-03-06 16:08:01 -08:00
Laurence
5ca4825800 refactor(dns): trie + unified record set for DNSRecordStore
- Replace four maps (aRecords, aaaaRecords, aWildcards, aaaaWildcards) with a label trie for exact lookups and a single wildcards map
- Store one recordSet (A + AAAA) per domain/pattern instead of separate A and AAAA maps
- Exact lookups O(labels); PTR unchanged (map); API and behaviour unchanged
2026-03-06 16:08:01 -08:00
Owen
809dbe77de Make chainId in relay message bckwd compat 2026-03-06 15:27:03 -08:00
Owen
c67c2a60a1 Handle canceling sends for relay 2026-03-06 15:15:31 -08:00
Owen
051c0fdfd8 Working jit with chain ids 2026-03-04 17:51:48 -08:00
Owen
e7507e0837 Add api endpoints to jit 2026-03-04 17:01:17 -08:00
15 changed files with 892 additions and 332 deletions

View File

@@ -78,6 +78,13 @@ type MetadataChangeRequest struct {
Postures map[string]any `json:"postures"`
}
// JITConnectionRequest defines the structure for a dynamic Just-In-Time connection request.
// Either SiteID or ResourceID must be provided (but not necessarily both).
type JITConnectionRequest struct {
Site string `json:"site,omitempty"`
Resource string `json:"resource,omitempty"`
}
// API represents the HTTP server and its state
type API struct {
addr string
@@ -92,6 +99,7 @@ type API struct {
onExit func() error
onRebind func() error
onPowerMode func(PowerModeRequest) error
onJITConnect func(JITConnectionRequest) error
statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus
@@ -143,6 +151,7 @@ func (s *API) SetHandlers(
onExit func() error,
onRebind func() error,
onPowerMode func(PowerModeRequest) error,
onJITConnect func(JITConnectionRequest) error,
) {
s.onConnect = onConnect
s.onSwitchOrg = onSwitchOrg
@@ -151,6 +160,7 @@ func (s *API) SetHandlers(
s.onExit = onExit
s.onRebind = onRebind
s.onPowerMode = onPowerMode
s.onJITConnect = onJITConnect
}
// Start starts the HTTP server
@@ -169,6 +179,7 @@ func (s *API) Start() error {
mux.HandleFunc("/health", s.handleHealth)
mux.HandleFunc("/rebind", s.handleRebind)
mux.HandleFunc("/power-mode", s.handlePowerMode)
mux.HandleFunc("/jit-connect", s.handleJITConnect)
s.server = &http.Server{
Handler: mux,
@@ -633,6 +644,54 @@ func (s *API) handleRebind(w http.ResponseWriter, r *http.Request) {
})
}
// handleJITConnect handles the /jit-connect endpoint.
// It initiates a dynamic Just-In-Time connection to a site identified by either
// a site or a resource. Exactly one of the two must be provided.
func (s *API) handleJITConnect(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req JITConnectionRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
return
}
// Validate that exactly one of site or resource is provided
if req.Site == "" && req.Resource == "" {
http.Error(w, "Missing required field: either site or resource must be provided", http.StatusBadRequest)
return
}
if req.Site != "" && req.Resource != "" {
http.Error(w, "Ambiguous request: provide either site or resource, not both", http.StatusBadRequest)
return
}
if req.Site != "" {
logger.Info("Received JIT connection request via API: site=%s", req.Site)
} else {
logger.Info("Received JIT connection request via API: resource=%s", req.Resource)
}
if s.onJITConnect != nil {
if err := s.onJITConnect(req); err != nil {
http.Error(w, fmt.Sprintf("JIT connection failed: %v", err), http.StatusInternalServerError)
return
}
} else {
http.Error(w, "JIT connect handler not configured", http.StatusNotImplemented)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusAccepted)
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "JIT connection request accepted",
})
}
// handlePowerMode handles the /power-mode endpoint
// This allows changing the power mode between "normal" and "low"
func (s *API) handlePowerMode(w http.ResponseWriter, r *http.Request) {

View File

@@ -45,6 +45,11 @@ type DNSProxy struct {
tunnelActivePorts map[uint16]bool
tunnelPortsLock sync.Mutex
// jitHandler is called when a local record is resolved for a site that may not be
// connected yet, giving the caller a chance to initiate a JIT connection.
// It is invoked asynchronously so it never blocks DNS resolution.
jitHandler func(siteId int)
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
@@ -384,6 +389,16 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie
response = p.checkLocalRecords(msg, question)
}
// If a local A/AAAA record was found, notify the JIT handler so that the owning
// site can be connected on-demand if it is not yet active.
if response != nil && p.jitHandler != nil &&
(question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA) {
if siteId, ok := p.recordStore.GetSiteIdForDomain(question.Name); ok && siteId != 0 {
handler := p.jitHandler
go handler(siteId)
}
}
// If no local records, forward to upstream
if response == nil {
logger.Debug("No local record for %s, forwarding upstream to %v", question.Name, p.upstreamDNS)
@@ -447,19 +462,20 @@ func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns
return nil
}
ips := p.recordStore.GetRecords(question.Name, recordType)
if len(ips) == 0 {
ips, exists := p.recordStore.GetRecords(question.Name, recordType)
if !exists {
// Domain not found in local records, forward to upstream
return nil
}
logger.Debug("Found %d local record(s) for %s", len(ips), question.Name)
// Create response message
// Create response message (NODATA if no records, otherwise with answers)
response := new(dns.Msg)
response.SetReply(query)
response.Authoritative = true
// Add answer records
// Add answer records (loop is a no-op if ips is empty)
for _, ip := range ips {
var rr dns.RR
if question.Qtype == dns.TypeA {
@@ -717,11 +733,20 @@ func (p *DNSProxy) runPacketSender() {
}
}
// SetJITHandler registers a callback that is invoked whenever a local DNS record is
// resolved for an A or AAAA query. The siteId identifies which site owns the record.
// The handler is called in its own goroutine so it must be safe to call concurrently.
// Pass nil to disable JIT notifications.
func (p *DNSProxy) SetJITHandler(handler func(siteId int)) {
p.jitHandler = handler
}
// AddDNSRecord adds a DNS record to the local store
// domain should be a domain name (e.g., "example.com" or "example.com.")
// ip should be a valid IPv4 or IPv6 address
func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP) error {
return p.recordStore.AddRecord(domain, ip)
func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP, siteId int) error {
logger.Debug("Adding dns record for domain %s with IP %s (siteId=%d)", domain, ip.String(), siteId)
return p.recordStore.AddRecord(domain, ip, siteId)
}
// RemoveDNSRecord removes a DNS record from the local store
@@ -730,8 +755,9 @@ func (p *DNSProxy) RemoveDNSRecord(domain string, ip net.IP) {
p.recordStore.RemoveRecord(domain, ip)
}
// GetDNSRecords returns all IP addresses for a domain and record type
func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) []net.IP {
// GetDNSRecords returns all IP addresses for a domain and record type.
// The second return value indicates whether the domain exists.
func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) ([]net.IP, bool) {
return p.recordStore.GetRecords(domain, recordType)
}

178
dns/dns_proxy_test.go Normal file
View File

@@ -0,0 +1,178 @@
package dns
import (
"net"
"testing"
"github.com/miekg/dns"
)
func TestCheckLocalRecordsNODATAForAAAA(t *testing.T) {
proxy := &DNSProxy{
recordStore: NewDNSRecordStore(),
}
// Add an A record for a domain (no AAAA record)
ip := net.ParseIP("10.0.0.1")
err := proxy.recordStore.AddRecord("myservice.internal", ip, 0)
if err != nil {
t.Fatalf("Failed to add A record: %v", err)
}
// Query AAAA for domain with only A record - should return NODATA
query := new(dns.Msg)
query.SetQuestion("myservice.internal.", dns.TypeAAAA)
response := proxy.checkLocalRecords(query, query.Question[0])
if response == nil {
t.Fatal("Expected NODATA response, got nil (would forward to upstream)")
}
if response.Rcode != dns.RcodeSuccess {
t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode)
}
if len(response.Answer) != 0 {
t.Errorf("Expected empty answer section for NODATA, got %d answers", len(response.Answer))
}
if !response.Authoritative {
t.Error("Expected response to be authoritative")
}
// Query A for same domain - should return the record
query = new(dns.Msg)
query.SetQuestion("myservice.internal.", dns.TypeA)
response = proxy.checkLocalRecords(query, query.Question[0])
if response == nil {
t.Fatal("Expected response with A record, got nil")
}
if len(response.Answer) != 1 {
t.Fatalf("Expected 1 answer, got %d", len(response.Answer))
}
aRecord, ok := response.Answer[0].(*dns.A)
if !ok {
t.Fatal("Expected A record in answer")
}
if !aRecord.A.Equal(ip.To4()) {
t.Errorf("Expected IP %v, got %v", ip.To4(), aRecord.A)
}
}
func TestCheckLocalRecordsNODATAForA(t *testing.T) {
proxy := &DNSProxy{
recordStore: NewDNSRecordStore(),
}
// Add an AAAA record for a domain (no A record)
ip := net.ParseIP("2001:db8::1")
err := proxy.recordStore.AddRecord("ipv6only.internal", ip, 0)
if err != nil {
t.Fatalf("Failed to add AAAA record: %v", err)
}
// Query A for domain with only AAAA record - should return NODATA
query := new(dns.Msg)
query.SetQuestion("ipv6only.internal.", dns.TypeA)
response := proxy.checkLocalRecords(query, query.Question[0])
if response == nil {
t.Fatal("Expected NODATA response, got nil")
}
if response.Rcode != dns.RcodeSuccess {
t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode)
}
if len(response.Answer) != 0 {
t.Errorf("Expected empty answer section, got %d answers", len(response.Answer))
}
if !response.Authoritative {
t.Error("Expected response to be authoritative")
}
// Query AAAA for same domain - should return the record
query = new(dns.Msg)
query.SetQuestion("ipv6only.internal.", dns.TypeAAAA)
response = proxy.checkLocalRecords(query, query.Question[0])
if response == nil {
t.Fatal("Expected response with AAAA record, got nil")
}
if len(response.Answer) != 1 {
t.Fatalf("Expected 1 answer, got %d", len(response.Answer))
}
aaaaRecord, ok := response.Answer[0].(*dns.AAAA)
if !ok {
t.Fatal("Expected AAAA record in answer")
}
if !aaaaRecord.AAAA.Equal(ip) {
t.Errorf("Expected IP %v, got %v", ip, aaaaRecord.AAAA)
}
}
func TestCheckLocalRecordsNonExistentDomain(t *testing.T) {
proxy := &DNSProxy{
recordStore: NewDNSRecordStore(),
}
// Add a record so the store isn't empty
err := proxy.recordStore.AddRecord("exists.internal", net.ParseIP("10.0.0.1"), 0)
if err != nil {
t.Fatalf("Failed to add record: %v", err)
}
// Query A for non-existent domain - should return nil (forward to upstream)
query := new(dns.Msg)
query.SetQuestion("unknown.internal.", dns.TypeA)
response := proxy.checkLocalRecords(query, query.Question[0])
if response != nil {
t.Error("Expected nil for non-existent domain, got response")
}
// Query AAAA for non-existent domain - should also return nil
query = new(dns.Msg)
query.SetQuestion("unknown.internal.", dns.TypeAAAA)
response = proxy.checkLocalRecords(query, query.Question[0])
if response != nil {
t.Error("Expected nil for non-existent domain, got response")
}
}
func TestCheckLocalRecordsNODATAWildcard(t *testing.T) {
proxy := &DNSProxy{
recordStore: NewDNSRecordStore(),
}
// Add a wildcard A record (no AAAA)
ip := net.ParseIP("10.0.0.1")
err := proxy.recordStore.AddRecord("*.wildcard.internal", ip, 0)
if err != nil {
t.Fatalf("Failed to add wildcard A record: %v", err)
}
// Query AAAA for wildcard-matched domain - should return NODATA
query := new(dns.Msg)
query.SetQuestion("host.wildcard.internal.", dns.TypeAAAA)
response := proxy.checkLocalRecords(query, query.Question[0])
if response == nil {
t.Fatal("Expected NODATA response for wildcard match, got nil")
}
if response.Rcode != dns.RcodeSuccess {
t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode)
}
if len(response.Answer) != 0 {
t.Errorf("Expected empty answer section, got %d answers", len(response.Answer))
}
// Query A for wildcard-matched domain - should return the record
query = new(dns.Msg)
query.SetQuestion("host.wildcard.internal.", dns.TypeA)
response = proxy.checkLocalRecords(query, query.Question[0])
if response == nil {
t.Fatal("Expected response with A record, got nil")
}
if len(response.Answer) != 1 {
t.Fatalf("Expected 1 answer, got %d", len(response.Answer))
}
}

View File

@@ -18,24 +18,28 @@ const (
RecordTypePTR RecordType = RecordType(dns.TypePTR)
)
// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries
// recordSet holds A and AAAA records for a single domain or wildcard pattern
type recordSet struct {
A []net.IP
AAAA []net.IP
SiteId int
}
// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries.
// Exact domains are stored in a map; wildcard patterns are in a separate map.
type DNSRecordStore struct {
mu sync.RWMutex
aRecords map[string][]net.IP // domain -> list of IPv4 addresses
aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses
aWildcards map[string][]net.IP // wildcard pattern -> list of IPv4 addresses
aaaaWildcards map[string][]net.IP // wildcard pattern -> list of IPv6 addresses
ptrRecords map[string]string // IP address string -> domain name
mu sync.RWMutex
exact map[string]*recordSet // normalized FQDN -> A/AAAA records
wildcards map[string]*recordSet // wildcard pattern -> A/AAAA records
ptrRecords map[string]string // IP address string -> domain name
}
// NewDNSRecordStore creates a new DNS record store
func NewDNSRecordStore() *DNSRecordStore {
return &DNSRecordStore{
aRecords: make(map[string][]net.IP),
aaaaRecords: make(map[string][]net.IP),
aWildcards: make(map[string][]net.IP),
aaaaWildcards: make(map[string][]net.IP),
ptrRecords: make(map[string]string),
exact: make(map[string]*recordSet),
wildcards: make(map[string]*recordSet),
ptrRecords: make(map[string]string),
}
}
@@ -43,47 +47,57 @@ func NewDNSRecordStore() *DNSRecordStore {
// domain should be in FQDN format (e.g., "example.com.")
// domain can contain wildcards: * (0+ chars) and ? (exactly 1 char)
// ip should be a valid IPv4 or IPv6 address
// siteId is the site that owns this alias/domain
// Automatically adds a corresponding PTR record for non-wildcard domains
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP, siteId int) error {
s.mu.Lock()
defer s.mu.Unlock()
// Ensure domain ends with a dot (FQDN format)
if len(domain) == 0 || domain[len(domain)-1] != '.' {
domain = domain + "."
}
// Normalize domain to lowercase FQDN
domain = strings.ToLower(dns.Fqdn(domain))
// Check if domain contains wildcards
isWildcard := strings.ContainsAny(domain, "*?")
if ip.To4() != nil {
// IPv4 address
if isWildcard {
s.aWildcards[domain] = append(s.aWildcards[domain], ip)
} else {
s.aRecords[domain] = append(s.aRecords[domain], ip)
// Automatically add PTR record for non-wildcard domains
s.ptrRecords[ip.String()] = domain
}
} else if ip.To16() != nil {
// IPv6 address
if isWildcard {
s.aaaaWildcards[domain] = append(s.aaaaWildcards[domain], ip)
} else {
s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip)
// Automatically add PTR record for non-wildcard domains
s.ptrRecords[ip.String()] = domain
}
} else {
isV4 := ip.To4() != nil
if !isV4 && ip.To16() == nil {
return &net.ParseError{Type: "IP address", Text: ip.String()}
}
// Choose the appropriate map based on whether this is a wildcard
m := s.exact
if isWildcard {
m = s.wildcards
}
if m[domain] == nil {
m[domain] = &recordSet{SiteId: siteId}
}
rs := m[domain]
if isV4 {
for _, existing := range rs.A {
if existing.Equal(ip) {
return nil
}
}
rs.A = append(rs.A, ip)
} else {
for _, existing := range rs.AAAA {
if existing.Equal(ip) {
return nil
}
}
rs.AAAA = append(rs.AAAA, ip)
}
// Add PTR record for non-wildcard domains
if !isWildcard {
s.ptrRecords[ip.String()] = domain
}
return nil
}
// AddPTRRecord adds a PTR record mapping an IP address to a domain name
// ip should be a valid IPv4 or IPv6 address
// domain should be in FQDN format (e.g., "example.com.")
@@ -112,89 +126,62 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
s.mu.Lock()
defer s.mu.Unlock()
// Ensure domain ends with a dot (FQDN format)
if len(domain) == 0 || domain[len(domain)-1] != '.' {
domain = domain + "."
}
// Normalize domain to lowercase FQDN
domain = strings.ToLower(dns.Fqdn(domain))
// Check if domain contains wildcards
isWildcard := strings.ContainsAny(domain, "*?")
if ip == nil {
// Remove all records for this domain
if isWildcard {
delete(s.aWildcards, domain)
delete(s.aaaaWildcards, domain)
} else {
// For non-wildcard domains, remove PTR records for all IPs
if ips, ok := s.aRecords[domain]; ok {
for _, ipAddr := range ips {
// Only remove PTR if it points to this domain
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ipAddr.String())
}
}
}
if ips, ok := s.aaaaRecords[domain]; ok {
for _, ipAddr := range ips {
// Only remove PTR if it points to this domain
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ipAddr.String())
}
}
}
delete(s.aRecords, domain)
delete(s.aaaaRecords, domain)
}
// Choose the appropriate map
m := s.exact
if isWildcard {
m = s.wildcards
}
rs := m[domain]
if rs == nil {
return
}
if ip == nil {
// Remove all records for this domain
if !isWildcard {
for _, ipAddr := range rs.A {
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ipAddr.String())
}
}
for _, ipAddr := range rs.AAAA {
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ipAddr.String())
}
}
}
delete(m, domain)
return
}
// Remove specific IP
if ip.To4() != nil {
// Remove specific IPv4 address
if isWildcard {
if ips, ok := s.aWildcards[domain]; ok {
s.aWildcards[domain] = removeIP(ips, ip)
if len(s.aWildcards[domain]) == 0 {
delete(s.aWildcards, domain)
}
}
} else {
if ips, ok := s.aRecords[domain]; ok {
s.aRecords[domain] = removeIP(ips, ip)
if len(s.aRecords[domain]) == 0 {
delete(s.aRecords, domain)
}
// Automatically remove PTR record if it points to this domain
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ip.String())
}
rs.A = removeIP(rs.A, ip)
if !isWildcard {
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ip.String())
}
}
} else if ip.To16() != nil {
// Remove specific IPv6 address
if isWildcard {
if ips, ok := s.aaaaWildcards[domain]; ok {
s.aaaaWildcards[domain] = removeIP(ips, ip)
if len(s.aaaaWildcards[domain]) == 0 {
delete(s.aaaaWildcards, domain)
}
}
} else {
if ips, ok := s.aaaaRecords[domain]; ok {
s.aaaaRecords[domain] = removeIP(ips, ip)
if len(s.aaaaRecords[domain]) == 0 {
delete(s.aaaaRecords, domain)
}
// Automatically remove PTR record if it points to this domain
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ip.String())
}
} else {
rs.AAAA = removeIP(rs.AAAA, ip)
if !isWildcard {
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ip.String())
}
}
}
// Clean up empty record sets
if len(rs.A) == 0 && len(rs.AAAA) == 0 {
delete(m, domain)
}
}
// RemovePTRRecord removes a PTR record for an IP address
@@ -205,61 +192,80 @@ func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) {
delete(s.ptrRecords, ip.String())
}
// GetRecords returns all IP addresses for a domain and record type
// First checks for exact matches, then checks wildcard patterns
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP {
// GetSiteIdForDomain returns the siteId associated with the given domain.
// It checks exact matches first, then wildcard patterns.
// The second return value is false if the domain is not found in local records.
func (s *DNSRecordStore) GetSiteIdForDomain(domain string) (int, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
// Normalize domain to lowercase FQDN
domain = strings.ToLower(dns.Fqdn(domain))
var records []net.IP
switch recordType {
case RecordTypeA:
// Check exact match first
if ips, ok := s.aRecords[domain]; ok {
// Return a copy to prevent external modifications
records = make([]net.IP, len(ips))
copy(records, ips)
return records
}
// Check wildcard patterns
for pattern, ips := range s.aWildcards {
if matchWildcard(pattern, domain) {
records = append(records, ips...)
}
}
if len(records) > 0 {
// Return a copy
result := make([]net.IP, len(records))
copy(result, records)
return result
}
// Check exact match first
if rs, exists := s.exact[domain]; exists {
return rs.SiteId, true
}
case RecordTypeAAAA:
// Check exact match first
if ips, ok := s.aaaaRecords[domain]; ok {
// Return a copy to prevent external modifications
records = make([]net.IP, len(ips))
copy(records, ips)
return records
}
// Check wildcard patterns
for pattern, ips := range s.aaaaWildcards {
if matchWildcard(pattern, domain) {
records = append(records, ips...)
}
}
if len(records) > 0 {
// Return a copy
result := make([]net.IP, len(records))
copy(result, records)
return result
// Check wildcard matches
for pattern, rs := range s.wildcards {
if matchWildcard(pattern, domain) {
return rs.SiteId, true
}
}
return records
return 0, false
}
// GetRecords returns all IP addresses for a domain and record type.
// The second return value indicates whether the domain exists at all
// (true = domain exists, use NODATA if no records; false = NXDOMAIN).
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) ([]net.IP, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
domain = strings.ToLower(dns.Fqdn(domain))
// Check exact match first
if rs, exists := s.exact[domain]; exists {
var ips []net.IP
if recordType == RecordTypeA {
ips = rs.A
} else {
ips = rs.AAAA
}
if len(ips) > 0 {
out := make([]net.IP, len(ips))
copy(out, ips)
return out, true
}
// Domain exists but no records of this type
return nil, true
}
// Check wildcard matches
var records []net.IP
matched := false
for pattern, rs := range s.wildcards {
if !matchWildcard(pattern, domain) {
continue
}
matched = true
if recordType == RecordTypeA {
records = append(records, rs.A...)
} else {
records = append(records, rs.AAAA...)
}
}
if !matched {
return nil, false
}
if len(records) == 0 {
return nil, true
}
out := make([]net.IP, len(records))
copy(out, records)
return out, true
}
// GetPTRRecord returns the domain name for a PTR record query
@@ -288,34 +294,30 @@ func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
s.mu.RLock()
defer s.mu.RUnlock()
// Normalize domain to lowercase FQDN
domain = strings.ToLower(dns.Fqdn(domain))
switch recordType {
case RecordTypeA:
// Check exact match
if _, ok := s.aRecords[domain]; ok {
// Check exact match
if rs, exists := s.exact[domain]; exists {
if recordType == RecordTypeA && len(rs.A) > 0 {
return true
}
// Check wildcard patterns
for pattern := range s.aWildcards {
if matchWildcard(pattern, domain) {
return true
}
}
case RecordTypeAAAA:
// Check exact match
if _, ok := s.aaaaRecords[domain]; ok {
if recordType == RecordTypeAAAA && len(rs.AAAA) > 0 {
return true
}
// Check wildcard patterns
for pattern := range s.aaaaWildcards {
if matchWildcard(pattern, domain) {
return true
}
}
}
// Check wildcard matches
for pattern, rs := range s.wildcards {
if !matchWildcard(pattern, domain) {
continue
}
if recordType == RecordTypeA && len(rs.A) > 0 {
return true
}
if recordType == RecordTypeAAAA && len(rs.AAAA) > 0 {
return true
}
}
return false
}
@@ -339,10 +341,8 @@ func (s *DNSRecordStore) Clear() {
s.mu.Lock()
defer s.mu.Unlock()
s.aRecords = make(map[string][]net.IP)
s.aaaaRecords = make(map[string][]net.IP)
s.aWildcards = make(map[string][]net.IP)
s.aaaaWildcards = make(map[string][]net.IP)
s.exact = make(map[string]*recordSet)
s.wildcards = make(map[string]*recordSet)
s.ptrRecords = make(map[string]string)
}
@@ -494,4 +494,4 @@ func IPToReverseDNS(ip net.IP) string {
}
return ""
}
}

View File

@@ -170,38 +170,47 @@ func TestDNSRecordStoreWildcard(t *testing.T) {
// Add wildcard records
wildcardIP := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", wildcardIP)
err := store.AddRecord("*.autoco.internal", wildcardIP, 0)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
// Add exact record
exactIP := net.ParseIP("10.0.0.2")
err = store.AddRecord("exact.autoco.internal", exactIP)
err = store.AddRecord("exact.autoco.internal", exactIP, 0)
if err != nil {
t.Fatalf("Failed to add exact record: %v", err)
}
// Test exact match takes precedence
ips := store.GetRecords("exact.autoco.internal.", RecordTypeA)
ips, exists := store.GetRecords("exact.autoco.internal.", RecordTypeA)
if !exists {
t.Error("Expected domain to exist")
}
if len(ips) != 1 {
t.Errorf("Expected 1 IP for exact match, got %d", len(ips))
}
if !ips[0].Equal(exactIP) {
if len(ips) > 0 && !ips[0].Equal(exactIP) {
t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0])
}
// Test wildcard match
ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
ips, exists = store.GetRecords("host.autoco.internal.", RecordTypeA)
if !exists {
t.Error("Expected wildcard match to exist")
}
if len(ips) != 1 {
t.Errorf("Expected 1 IP for wildcard match, got %d", len(ips))
}
if !ips[0].Equal(wildcardIP) {
if len(ips) > 0 && !ips[0].Equal(wildcardIP) {
t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0])
}
// Test non-match (base domain)
ips = store.GetRecords("autoco.internal.", RecordTypeA)
ips, exists = store.GetRecords("autoco.internal.", RecordTypeA)
if exists {
t.Error("Expected base domain to not exist")
}
if len(ips) != 0 {
t.Errorf("Expected 0 IPs for base domain, got %d", len(ips))
}
@@ -212,13 +221,16 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) {
// Add complex wildcard pattern
ip1 := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.host-0?.autoco.internal", ip1)
err := store.AddRecord("*.host-0?.autoco.internal", ip1, 0)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
// Test matching domain
ips := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA)
ips, exists := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA)
if !exists {
t.Error("Expected complex wildcard match to exist")
}
if len(ips) != 1 {
t.Errorf("Expected 1 IP for complex wildcard match, got %d", len(ips))
}
@@ -227,13 +239,19 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) {
}
// Test non-matching domain (missing prefix)
ips = store.GetRecords("host-01.autoco.internal.", RecordTypeA)
ips, exists = store.GetRecords("host-01.autoco.internal.", RecordTypeA)
if exists {
t.Error("Expected domain without prefix to not exist")
}
if len(ips) != 0 {
t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips))
}
// Test non-matching domain (wrong ? position)
ips = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA)
ips, exists = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA)
if exists {
t.Error("Expected domain with wrong ? match to not exist")
}
if len(ips) != 0 {
t.Errorf("Expected 0 IPs for domain with wrong ? match, got %d", len(ips))
}
@@ -244,13 +262,16 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
// Add wildcard record
ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", ip)
err := store.AddRecord("*.autoco.internal", ip, 0)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
// Verify it exists
ips := store.GetRecords("host.autoco.internal.", RecordTypeA)
ips, exists := store.GetRecords("host.autoco.internal.", RecordTypeA)
if !exists {
t.Error("Expected domain to exist before removal")
}
if len(ips) != 1 {
t.Errorf("Expected 1 IP before removal, got %d", len(ips))
}
@@ -259,7 +280,10 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
store.RemoveRecord("*.autoco.internal", nil)
// Verify it's gone
ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
ips, exists = store.GetRecords("host.autoco.internal.", RecordTypeA)
if exists {
t.Error("Expected domain to not exist after removal")
}
if len(ips) != 0 {
t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
}
@@ -273,36 +297,36 @@ func TestDNSRecordStoreMultipleWildcards(t *testing.T) {
ip2 := net.ParseIP("10.0.0.2")
ip3 := net.ParseIP("10.0.0.3")
err := store.AddRecord("*.prod.autoco.internal", ip1)
err := store.AddRecord("*.prod.autoco.internal", ip1, 0)
if err != nil {
t.Fatalf("Failed to add first wildcard: %v", err)
}
err = store.AddRecord("*.dev.autoco.internal", ip2)
err = store.AddRecord("*.dev.autoco.internal", ip2, 0)
if err != nil {
t.Fatalf("Failed to add second wildcard: %v", err)
}
// Add a broader wildcard that matches both
err = store.AddRecord("*.autoco.internal", ip3)
err = store.AddRecord("*.autoco.internal", ip3, 0)
if err != nil {
t.Fatalf("Failed to add third wildcard: %v", err)
}
// Test domain matching only the prod pattern and the broad pattern
ips := store.GetRecords("host.prod.autoco.internal.", RecordTypeA)
ips, _ := store.GetRecords("host.prod.autoco.internal.", RecordTypeA)
if len(ips) != 2 {
t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips))
}
// Test domain matching only the dev pattern and the broad pattern
ips = store.GetRecords("service.dev.autoco.internal.", RecordTypeA)
ips, _ = store.GetRecords("service.dev.autoco.internal.", RecordTypeA)
if len(ips) != 2 {
t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips))
}
// Test domain matching only the broad pattern
ips = store.GetRecords("host.test.autoco.internal.", RecordTypeA)
ips, _ = store.GetRecords("host.test.autoco.internal.", RecordTypeA)
if len(ips) != 1 {
t.Errorf("Expected 1 IP (broad only), got %d", len(ips))
}
@@ -313,13 +337,13 @@ func TestDNSRecordStoreIPv6Wildcard(t *testing.T) {
// Add IPv6 wildcard record
ip := net.ParseIP("2001:db8::1")
err := store.AddRecord("*.autoco.internal", ip)
err := store.AddRecord("*.autoco.internal", ip, 0)
if err != nil {
t.Fatalf("Failed to add IPv6 wildcard record: %v", err)
}
// Test wildcard match for IPv6
ips := store.GetRecords("host.autoco.internal.", RecordTypeAAAA)
ips, _ := store.GetRecords("host.autoco.internal.", RecordTypeAAAA)
if len(ips) != 1 {
t.Errorf("Expected 1 IPv6 for wildcard match, got %d", len(ips))
}
@@ -333,7 +357,7 @@ func TestHasRecordWildcard(t *testing.T) {
// Add wildcard record
ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", ip)
err := store.AddRecord("*.autoco.internal", ip, 0)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
@@ -354,7 +378,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
// Add record with mixed case
ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("MyHost.AutoCo.Internal", ip)
err := store.AddRecord("MyHost.AutoCo.Internal", ip, 0)
if err != nil {
t.Fatalf("Failed to add mixed case record: %v", err)
}
@@ -368,7 +392,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
}
for _, domain := range testCases {
ips := store.GetRecords(domain, RecordTypeA)
ips, _ := store.GetRecords(domain, RecordTypeA)
if len(ips) != 1 {
t.Errorf("Expected 1 IP for domain %q, got %d", domain, len(ips))
}
@@ -379,7 +403,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
// Test wildcard with mixed case
wildcardIP := net.ParseIP("10.0.0.2")
err = store.AddRecord("*.Example.Com", wildcardIP)
err = store.AddRecord("*.Example.Com", wildcardIP, 0)
if err != nil {
t.Fatalf("Failed to add mixed case wildcard: %v", err)
}
@@ -392,7 +416,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
}
for _, domain := range wildcardTestCases {
ips := store.GetRecords(domain, RecordTypeA)
ips, _ := store.GetRecords(domain, RecordTypeA)
if len(ips) != 1 {
t.Errorf("Expected 1 IP for wildcard domain %q, got %d", domain, len(ips))
}
@@ -403,7 +427,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
// Test removal with different case
store.RemoveRecord("MYHOST.AUTOCO.INTERNAL", nil)
ips := store.GetRecords("myhost.autoco.internal.", RecordTypeA)
ips, _ := store.GetRecords("myhost.autoco.internal.", RecordTypeA)
if len(ips) != 0 {
t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
}
@@ -665,7 +689,7 @@ func TestClearPTRRecords(t *testing.T) {
store.AddPTRRecord(ip2, "host2.example.com.")
// Add some A records too
store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1"))
store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1"), 0)
// Verify PTR records exist
if !store.HasPTRRecord("1.1.168.192.in-addr.arpa.") {
@@ -695,7 +719,7 @@ func TestAutomaticPTRRecordOnAdd(t *testing.T) {
// Add an A record - should automatically add PTR record
domain := "host.example.com."
ip := net.ParseIP("192.168.1.100")
err := store.AddRecord(domain, ip)
err := store.AddRecord(domain, ip, 0)
if err != nil {
t.Fatalf("Failed to add A record: %v", err)
}
@@ -713,7 +737,7 @@ func TestAutomaticPTRRecordOnAdd(t *testing.T) {
// Add AAAA record - should also automatically add PTR record
domain6 := "ipv6host.example.com."
ip6 := net.ParseIP("2001:db8::1")
err = store.AddRecord(domain6, ip6)
err = store.AddRecord(domain6, ip6, 0)
if err != nil {
t.Fatalf("Failed to add AAAA record: %v", err)
}
@@ -735,7 +759,7 @@ func TestAutomaticPTRRecordOnRemove(t *testing.T) {
// Add an A record (with automatic PTR)
domain := "host.example.com."
ip := net.ParseIP("192.168.1.100")
store.AddRecord(domain, ip)
store.AddRecord(domain, ip, 0)
// Verify PTR exists
reverseDomain := "100.1.168.192.in-addr.arpa."
@@ -752,7 +776,7 @@ func TestAutomaticPTRRecordOnRemove(t *testing.T) {
}
// Verify A record is also gone
ips := store.GetRecords(domain, RecordTypeA)
ips, _ := store.GetRecords(domain, RecordTypeA)
if len(ips) != 0 {
t.Errorf("Expected A record to be removed, got %d records", len(ips))
}
@@ -765,8 +789,8 @@ func TestAutomaticPTRRecordOnRemoveAll(t *testing.T) {
domain := "host.example.com."
ip1 := net.ParseIP("192.168.1.100")
ip2 := net.ParseIP("192.168.1.101")
store.AddRecord(domain, ip1)
store.AddRecord(domain, ip2)
store.AddRecord(domain, ip1, 0)
store.AddRecord(domain, ip2, 0)
// Verify both PTR records exist
reverseDomain1 := "100.1.168.192.in-addr.arpa."
@@ -796,7 +820,7 @@ func TestNoPTRForWildcardRecords(t *testing.T) {
// Add wildcard record - should NOT create PTR record
domain := "*.example.com."
ip := net.ParseIP("192.168.1.100")
err := store.AddRecord(domain, ip)
err := store.AddRecord(domain, ip, 0)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
@@ -820,7 +844,7 @@ func TestPTRRecordOverwrite(t *testing.T) {
// Add first domain with IP
domain1 := "host1.example.com."
ip := net.ParseIP("192.168.1.100")
store.AddRecord(domain1, ip)
store.AddRecord(domain1, ip, 0)
// Verify PTR points to first domain
reverseDomain := "100.1.168.192.in-addr.arpa."
@@ -834,7 +858,7 @@ func TestPTRRecordOverwrite(t *testing.T) {
// Add second domain with same IP - should overwrite PTR
domain2 := "host2.example.com."
store.AddRecord(domain2, ip)
store.AddRecord(domain2, ip, 0)
// Verify PTR now points to second domain (last one added)
result, ok = store.GetPTRRecord(reverseDomain)

14
go.mod
View File

@@ -1,14 +1,14 @@
module github.com/fosrl/olm
go 1.25
go 1.25.0
require (
github.com/Microsoft/go-winio v0.6.2
github.com/fosrl/newt v1.9.0
github.com/fosrl/newt v1.10.3
github.com/godbus/dbus/v5 v5.2.2
github.com/gorilla/websocket v1.5.3
github.com/miekg/dns v1.1.70
golang.org/x/sys v0.40.0
golang.org/x/sys v0.41.0
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c
@@ -20,13 +20,13 @@ require (
github.com/google/go-cmp v0.7.0 // indirect
github.com/vishvananda/netlink v1.3.1 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
golang.org/x/crypto v0.46.0 // indirect
golang.org/x/crypto v0.48.0 // indirect
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect
golang.org/x/mod v0.31.0 // indirect
golang.org/x/net v0.48.0 // indirect
golang.org/x/mod v0.32.0 // indirect
golang.org/x/net v0.51.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/time v0.12.0 // indirect
golang.org/x/tools v0.40.0 // indirect
golang.org/x/tools v0.41.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
)

24
go.sum
View File

@@ -1,7 +1,7 @@
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/fosrl/newt v1.9.0 h1:66eJMo6fA+YcBTbddxTfNJXNQo1WWKzmn6zPRP5kSDE=
github.com/fosrl/newt v1.9.0/go.mod h1:d1+yYMnKqg4oLqAM9zdbjthjj2FQEVouiACjqU468ck=
github.com/fosrl/newt v1.10.3 h1:JO9gFK9LP/w2EeDIn4wU+jKggAFPo06hX5hxFSETqcw=
github.com/fosrl/newt v1.10.3/go.mod h1:iYuuCAG7iabheiogMOX87r61uQN31S39nKxMKRuLS+s=
github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ=
github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
@@ -16,24 +16,24 @@ github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=

View File

@@ -7,6 +7,7 @@ import (
"runtime"
"strconv"
"strings"
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/network"
@@ -168,20 +169,25 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
SharedBind: o.sharedBind,
WSClient: o.websocket,
APIServer: o.apiServer,
PublicDNS: o.tunnelConfig.PublicDNS,
})
for i := range wgData.Sites {
site := wgData.Sites[i]
var siteEndpoint string
// here we are going to take the relay endpoint if it exists which means we requested a relay for this peer
if site.RelayEndpoint != "" {
siteEndpoint = site.RelayEndpoint
} else {
siteEndpoint = site.Endpoint
if site.PublicKey != "" {
var siteEndpoint string
// here we are going to take the relay endpoint if it exists which means we requested a relay for this peer
if site.RelayEndpoint != "" {
siteEndpoint = site.RelayEndpoint
} else {
siteEndpoint = site.Endpoint
}
o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false)
}
o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false)
// we still call this to add the aliases for jit lookup but we just do that then pass inside. need to skip the above so we dont add to the api
if err := o.peerManager.AddPeer(site); err != nil {
logger.Error("Failed to add peer: %v", err)
return
@@ -196,6 +202,36 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
logger.Error("Failed to start DNS proxy: %v", err)
}
// Register JIT handler: when the DNS proxy resolves a local record, check whether
// the owning site is already connected and, if not, initiate a JIT connection.
o.dnsProxy.SetJITHandler(func(siteId int) {
if o.peerManager == nil || o.websocket == nil {
return
}
// Site already has an active peer connection - nothing to do.
if _, exists := o.peerManager.GetPeer(siteId); exists {
return
}
o.peerSendMu.Lock()
defer o.peerSendMu.Unlock()
// A JIT request for this site is already in-flight - avoid duplicate sends.
if _, pending := o.jitPendingSites[siteId]; pending {
return
}
chainId := generateChainId()
logger.Info("DNS-triggered JIT connect for site %d (chainId=%s)", siteId, chainId)
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{
"siteId": siteId,
"chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerInits[chainId] = stopFunc
o.jitPendingSites[siteId] = chainId
})
if o.tunnelConfig.OverrideDNS {
// Set up DNS override to use our DNS proxy
if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil {
@@ -273,12 +309,12 @@ func (o *Olm) handleTerminate(msg websocket.WSMessage) {
logger.Error("Error unmarshaling terminate error data: %v", err)
} else {
logger.Info("Terminate reason (code: %s): %s", errorData.Code, errorData.Message)
if errorData.Code == "TERMINATED_INACTIVITY" {
logger.Info("Ignoring...")
return
}
// Set the olm error in the API server so it can be exposed via status
o.apiServer.SetOlmError(errorData.Code, errorData.Message)
}

View File

@@ -2,6 +2,7 @@ package olm
import (
"encoding/json"
"fmt"
"time"
"github.com/fosrl/newt/holepunch"
@@ -220,6 +221,7 @@ func (o *Olm) handleSync(msg websocket.WSMessage) {
logger.Info("Sync: Adding new peer for site %d", siteId)
o.holePunchManager.TriggerHolePunch()
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
// // TODO: do we need to send the message to the cloud to add the peer that way?
// if err := o.peerManager.AddPeer(expectedSite); err != nil {
@@ -230,9 +232,17 @@ func (o *Olm) handleSync(msg websocket.WSMessage) {
// add the peer via the server
// this is important because newt needs to get triggered as well to add the peer once the hp is complete
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": expectedSite.SiteId,
}, 1*time.Second, 10)
chainId := fmt.Sprintf("sync-%d", expectedSite.SiteId)
o.peerSendMu.Lock()
if stop, ok := o.stopPeerSends[chainId]; ok {
stop()
}
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": expectedSite.SiteId,
"chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerSends[chainId] = stopFunc
o.peerSendMu.Unlock()
} else {
// Existing peer - check if update is needed

View File

@@ -2,6 +2,8 @@ package olm
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"net"
"net/http"
@@ -31,7 +33,7 @@ type Olm struct {
privateKey wgtypes.Key
logFile *os.File
registered bool
registered bool
tunnelRunning bool
uapiListener net.Listener
@@ -65,7 +67,10 @@ type Olm struct {
stopRegister func()
updateRegister func(newData any)
stopPeerSend func()
stopPeerSends map[string]func()
stopPeerInits map[string]func()
jitPendingSites map[int]string // siteId -> chainId for in-flight JIT requests
peerSendMu sync.Mutex
// WaitGroup to track tunnel lifecycle
tunnelWg sync.WaitGroup
@@ -111,11 +116,18 @@ func (o *Olm) initTunnelInfo(clientID string) error {
logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount())
// Create the holepunch manager
o.holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm", privateKey.PublicKey().String())
o.holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm", privateKey.PublicKey().String(), o.tunnelConfig.PublicDNS)
return nil
}
// generateChainId generates a random chain ID for tracking peer sender lifecycles.
func generateChainId() string {
b := make([]byte, 8)
_, _ = rand.Read(b)
return hex.EncodeToString(b)
}
func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
@@ -166,10 +178,13 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
apiServer.SetAgent(config.Agent)
newOlm := &Olm{
logFile: logFile,
olmCtx: ctx,
apiServer: apiServer,
olmConfig: config,
logFile: logFile,
olmCtx: ctx,
apiServer: apiServer,
olmConfig: config,
stopPeerSends: make(map[string]func()),
stopPeerInits: make(map[string]func()),
jitPendingSites: make(map[int]string),
}
newOlm.registerAPICallbacks()
@@ -222,7 +237,7 @@ func (o *Olm) registerAPICallbacks() {
tunnelConfig.MTU = 1420
}
if req.DNS == "" {
tunnelConfig.DNS = "9.9.9.9"
tunnelConfig.DNS = "8.8.8.8"
}
// DNSProxyIP has no default - it must be provided if DNS proxy is desired
// UpstreamDNS defaults to 8.8.8.8 if not provided
@@ -284,6 +299,21 @@ func (o *Olm) registerAPICallbacks() {
logger.Info("Processing power mode change request via API: mode=%s", req.Mode)
return o.SetPowerMode(req.Mode)
},
func(req api.JITConnectionRequest) error {
logger.Info("Processing JIT connect request via API: site=%s resource=%s", req.Site, req.Resource)
chainId := generateChainId()
o.peerSendMu.Lock()
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{
"siteId": req.Site,
"resourceId": req.Resource,
"chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerInits[chainId] = stopFunc
o.peerSendMu.Unlock()
return nil
},
)
}
@@ -292,16 +322,23 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
logger.Info("Tunnel already running")
return
}
// debug print out the whole config
logger.Debug("Starting tunnel with config: %+v", config)
o.tunnelRunning = true // Also set it here in case it is called externally
o.tunnelConfig = config
// TODO: we are hardcoding this for now but we should really pull it from the current config of the system
if o.tunnelConfig.DNS != "" {
o.tunnelConfig.PublicDNS = []string{o.tunnelConfig.DNS + ":53"}
} else {
o.tunnelConfig.PublicDNS = []string{"8.8.8.8:53"}
}
// Reset terminated status when tunnel starts
o.apiServer.SetTerminated(false)
fingerprint := config.InitialFingerprint
if fingerprint == nil {
fingerprint = make(map[string]any)
@@ -313,7 +350,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
}
o.SetFingerprint(fingerprint)
o.SetPostures(postures)
o.SetPostures(postures)
// Create a cancellable context for this tunnel process
tunnelCtx, cancel := context.WithCancel(o.olmCtx)
@@ -378,6 +415,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
// Handler for peer handshake - adds exit node to holepunch rotation and notifies server
o.websocket.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite)
o.websocket.RegisterHandler("olm/wg/peer/chain/cancel", o.handleCancelChain)
o.websocket.RegisterHandler("olm/sync", o.handleSync)
o.websocket.OnConnect(func() error {
@@ -387,7 +425,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
if o.registered {
o.websocket.StartPingMonitor()
logger.Debug("Already registered, skipping registration")
return nil
}
@@ -420,7 +458,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
"userToken": userToken,
"fingerprint": o.fingerprint,
"postures": o.postures,
}, 1*time.Second, 10)
}, 2*time.Second, 10)
// Invoke onRegistered callback if configured
if o.olmConfig.OnRegistered != nil {
@@ -517,6 +555,23 @@ func (o *Olm) Close() {
o.stopRegister = nil
}
// Stop all pending peer init and send senders before closing websocket
o.peerSendMu.Lock()
for _, stop := range o.stopPeerInits {
if stop != nil {
stop()
}
}
o.stopPeerInits = make(map[string]func())
for _, stop := range o.stopPeerSends {
if stop != nil {
stop()
}
}
o.stopPeerSends = make(map[string]func())
o.jitPendingSites = make(map[int]string)
o.peerSendMu.Unlock()
// send a disconnect message to the cloud to show disconnected
if o.websocket != nil {
o.websocket.SendMessage("olm/disconnecting", map[string]any{})

View File

@@ -20,31 +20,43 @@ func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
return
}
if o.stopPeerSend != nil {
o.stopPeerSend()
o.stopPeerSend = nil
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var siteConfig peers.SiteConfig
if err := json.Unmarshal(jsonData, &siteConfig); err != nil {
var siteConfigMsg struct {
peers.SiteConfig
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &siteConfigMsg); err != nil {
logger.Error("Error unmarshaling add data: %v", err)
return
}
if siteConfigMsg.ChainId != "" {
o.peerSendMu.Lock()
if stop, ok := o.stopPeerSends[siteConfigMsg.ChainId]; ok {
stop()
delete(o.stopPeerSends, siteConfigMsg.ChainId)
}
o.peerSendMu.Unlock()
}
if siteConfigMsg.PublicKey == "" {
logger.Warn("Skipping add-peer for site %d (%s): no public key available (site may not be connected)", siteConfigMsg.SiteId, siteConfigMsg.Name)
return
}
_ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
if err := o.peerManager.AddPeer(siteConfig); err != nil {
if err := o.peerManager.AddPeer(siteConfigMsg.SiteConfig); err != nil {
logger.Error("Failed to add peer: %v", err)
return
}
logger.Info("Successfully added peer for site %d", siteConfig.SiteId)
logger.Info("Successfully added peer for site %d", siteConfigMsg.SiteId)
}
func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
@@ -164,13 +176,21 @@ func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) {
return
}
var relayData peers.RelayPeerData
var relayData struct {
peers.RelayPeerData
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err)
return
}
primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint)
if monitor := o.peerManager.GetPeerMonitor(); monitor != nil {
monitor.CancelRelaySend(relayData.ChainId)
}
primaryRelay, err := util.ResolveDomainUpstream(relayData.RelayEndpoint, o.tunnelConfig.PublicDNS)
if err != nil {
logger.Error("Failed to resolve primary relay endpoint: %v", err)
return
@@ -197,13 +217,21 @@ func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
return
}
var relayData peers.UnRelayPeerData
var relayData struct {
peers.UnRelayPeerData
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err)
return
}
primaryRelay, err := util.ResolveDomain(relayData.Endpoint)
if monitor := o.peerManager.GetPeerMonitor(); monitor != nil {
monitor.CancelRelaySend(relayData.ChainId)
}
primaryRelay, err := util.ResolveDomainUpstream(relayData.Endpoint, o.tunnelConfig.PublicDNS)
if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
}
@@ -230,7 +258,8 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
}
var handshakeData struct {
SiteId int `json:"siteId"`
SiteId int `json:"siteId"`
ChainId string `json:"chainId"`
ExitNode struct {
PublicKey string `json:"publicKey"`
Endpoint string `json:"endpoint"`
@@ -243,6 +272,19 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
return
}
// Stop the peer init sender for this chain, if any
if handshakeData.ChainId != "" {
o.peerSendMu.Lock()
if stop, ok := o.stopPeerInits[handshakeData.ChainId]; ok {
stop()
delete(o.stopPeerInits, handshakeData.ChainId)
}
// If this chain was initiated by a DNS-triggered JIT request, clear the
// pending entry so the site can be re-triggered if needed in the future.
delete(o.jitPendingSites, handshakeData.SiteId)
o.peerSendMu.Unlock()
}
// Get existing peer from PeerManager
_, exists := o.peerManager.GetPeer(handshakeData.SiteId)
if exists {
@@ -273,10 +315,72 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
// Send handshake acknowledgment back to server with retry
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": handshakeData.SiteId,
}, 1*time.Second, 10)
// Send handshake acknowledgment back to server with retry, keyed by chainId
chainId := handshakeData.ChainId
if chainId == "" {
chainId = generateChainId()
}
o.peerSendMu.Lock()
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": handshakeData.SiteId,
"chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerSends[chainId] = stopFunc
o.peerSendMu.Unlock()
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
}
func (o *Olm) handleCancelChain(msg websocket.WSMessage) {
logger.Debug("Received cancel-chain message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling cancel-chain data: %v", err)
return
}
var cancelData struct {
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &cancelData); err != nil {
logger.Error("Error unmarshaling cancel-chain data: %v", err)
return
}
if cancelData.ChainId == "" {
logger.Warn("Received cancel-chain message with no chainId")
return
}
o.peerSendMu.Lock()
defer o.peerSendMu.Unlock()
found := false
if stop, ok := o.stopPeerInits[cancelData.ChainId]; ok {
stop()
delete(o.stopPeerInits, cancelData.ChainId)
found = true
}
// If this chain was a DNS-triggered JIT request, clear the pending entry so
// the site can be re-triggered on the next DNS lookup.
for siteId, chainId := range o.jitPendingSites {
if chainId == cancelData.ChainId {
delete(o.jitPendingSites, siteId)
break
}
}
if stop, ok := o.stopPeerSends[cancelData.ChainId]; ok {
stop()
delete(o.stopPeerSends, cancelData.ChainId)
found = true
}
if found {
logger.Info("Cancelled chain %s", cancelData.ChainId)
} else {
logger.Warn("Cancel-chain: no active sender found for chain %s", cancelData.ChainId)
}
}

View File

@@ -61,6 +61,7 @@ type TunnelConfig struct {
MTU int
DNS string
UpstreamDNS []string
PublicDNS []string
InterfaceName string
// Advanced

View File

@@ -32,7 +32,8 @@ type PeerManagerConfig struct {
SharedBind *bind.SharedBind
// WSClient is optional - if nil, relay messages won't be sent
WSClient *websocket.Client
APIServer *api.API
APIServer *api.API
PublicDNS []string
}
type PeerManager struct {
@@ -50,7 +51,8 @@ type PeerManager struct {
// key is the CIDR string, value is a set of siteIds that want this IP
allowedIPClaims map[string]map[int]bool
APIServer *api.API
publicDNS []string
PersistentKeepalive int
}
@@ -65,6 +67,7 @@ func NewPeerManager(config PeerManagerConfig) *PeerManager {
allowedIPOwners: make(map[string]int),
allowedIPClaims: make(map[string]map[int]bool),
APIServer: config.APIServer,
publicDNS: config.PublicDNS,
}
// Create the peer monitor
@@ -74,6 +77,7 @@ func NewPeerManager(config PeerManagerConfig) *PeerManager {
config.LocalIP,
config.SharedBind,
config.APIServer,
config.PublicDNS,
)
return pm
@@ -106,6 +110,19 @@ func (pm *PeerManager) GetAllPeers() []SiteConfig {
func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
pm.mu.Lock()
defer pm.mu.Unlock()
for _, alias := range siteConfig.Aliases {
address := net.ParseIP(alias.AliasAddress)
if address == nil {
continue
}
pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteConfig.SiteId)
}
if siteConfig.PublicKey == "" {
logger.Debug("Skip adding site %d because no pub key", siteConfig.SiteId)
return nil
}
// build the allowed IPs list from the remote subnets and aliases and add them to the peer
allowedIPs := make([]string, 0, len(siteConfig.RemoteSubnets)+len(siteConfig.Aliases))
@@ -129,7 +146,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
wgConfig := siteConfig
wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil {
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive, pm.publicDNS); err != nil {
return err
}
@@ -139,14 +156,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
if err := network.AddRoutes(siteConfig.RemoteSubnets, pm.interfaceName); err != nil {
logger.Error("Failed to add routes for remote subnets: %v", err)
}
for _, alias := range siteConfig.Aliases {
address := net.ParseIP(alias.AliasAddress)
if address == nil {
continue
}
pm.dnsProxy.AddDNSRecord(alias.Alias, address)
}
monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
@@ -270,7 +280,7 @@ func (pm *PeerManager) RemovePeer(siteId int) error {
ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
wgConfig := promotedPeer
wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil {
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive, pm.publicDNS); err != nil {
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
}
}
@@ -346,7 +356,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
wgConfig := siteConfig
wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil {
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive, pm.publicDNS); err != nil {
return err
}
@@ -356,7 +366,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
promotedWgConfig := promotedPeer
promotedWgConfig.AllowedIps = promotedOwnedIPs
if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil {
if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive, pm.publicDNS); err != nil {
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
}
}
@@ -433,7 +443,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
if address == nil {
continue
}
pm.dnsProxy.AddDNSRecord(alias.Alias, address)
pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteConfig.SiteId)
}
pm.peerMonitor.UpdateHolepunchEndpoint(siteConfig.SiteId, siteConfig.Endpoint)
@@ -713,7 +723,7 @@ func (pm *PeerManager) AddAlias(siteId int, alias Alias) error {
address := net.ParseIP(alias.AliasAddress)
if address != nil {
pm.dnsProxy.AddDNSRecord(alias.Alias, address)
pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteId)
}
// Add an allowed IP for the alias

View File

@@ -2,6 +2,8 @@ package monitor
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"net"
"net/netip"
@@ -31,9 +33,14 @@ type PeerMonitor struct {
monitors map[int]*Client
mutex sync.Mutex
running bool
timeout time.Duration
timeout time.Duration
maxAttempts int
wsClient *websocket.Client
publicDNS []string
// Relay sender tracking
relaySends map[string]func()
relaySendMu sync.Mutex
// Netstack fields
middleDev *middleDevice.MiddleDevice
@@ -47,13 +54,13 @@ type PeerMonitor struct {
nsWg sync.WaitGroup
// Holepunch testing fields
sharedBind *bind.SharedBind
holepunchTester *holepunch.HolepunchTester
holepunchTimeout time.Duration
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
holepunchStatus map[int]bool // siteID -> connected status
holepunchStopChan chan struct{}
holepunchUpdateChan chan struct{}
sharedBind *bind.SharedBind
holepunchTester *holepunch.HolepunchTester
holepunchTimeout time.Duration
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
holepunchStatus map[int]bool // siteID -> connected status
holepunchStopChan chan struct{}
holepunchUpdateChan chan struct{}
// Relay tracking fields
relayedPeers map[int]bool // siteID -> whether the peer is currently relayed
@@ -82,7 +89,13 @@ type PeerMonitor struct {
}
// NewPeerMonitor creates a new peer monitor with the given callback
func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API) *PeerMonitor {
func generateChainId() string {
b := make([]byte, 8)
_, _ = rand.Read(b)
return hex.EncodeToString(b)
}
func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API, publicDNS []string) *PeerMonitor {
ctx, cancel := context.WithCancel(context.Background())
pm := &PeerMonitor{
monitors: make(map[int]*Client),
@@ -91,6 +104,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
wsClient: wsClient,
middleDev: middleDev,
localIP: localIP,
publicDNS: publicDNS,
activePorts: make(map[uint16]bool),
nsCtx: ctx,
nsCancel: cancel,
@@ -99,6 +113,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
holepunchEndpoints: make(map[int]string),
holepunchStatus: make(map[int]bool),
relayedPeers: make(map[int]bool),
relaySends: make(map[string]func()),
holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures
holepunchFailures: make(map[int]int),
// Rapid initial test settings: complete within ~1.5 seconds
@@ -124,7 +139,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
// Initialize holepunch tester if sharedBind is available
if sharedBind != nil {
pm.holepunchTester = holepunch.NewHolepunchTester(sharedBind)
pm.holepunchTester = holepunch.NewHolepunchTester(sharedBind, publicDNS)
}
return pm
@@ -396,20 +411,23 @@ func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status Connectio
}
}
// sendRelay sends a relay message to the server
// sendRelay sends a relay message to the server with retry, keyed by chainId
func (pm *PeerMonitor) sendRelay(siteID int) error {
if pm.wsClient == nil {
return fmt.Errorf("websocket client is nil")
}
err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{
"siteId": siteID,
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err
}
logger.Info("Sent relay message")
chainId := generateChainId()
stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/relay", map[string]interface{}{
"siteId": siteID,
"chainId": chainId,
}, 2*time.Second, 10)
pm.relaySendMu.Lock()
pm.relaySends[chainId] = stopFunc
pm.relaySendMu.Unlock()
logger.Info("Sent relay message for site %d (chain %s)", siteID, chainId)
return nil
}
@@ -419,23 +437,52 @@ func (pm *PeerMonitor) RequestRelay(siteID int) error {
return pm.sendRelay(siteID)
}
// sendUnRelay sends an unrelay message to the server
// sendUnRelay sends an unrelay message to the server with retry, keyed by chainId
func (pm *PeerMonitor) sendUnRelay(siteID int) error {
if pm.wsClient == nil {
return fmt.Errorf("websocket client is nil")
}
err := pm.wsClient.SendMessage("olm/wg/unrelay", map[string]interface{}{
"siteId": siteID,
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err
}
logger.Info("Sent unrelay message")
chainId := generateChainId()
stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/unrelay", map[string]interface{}{
"siteId": siteID,
"chainId": chainId,
}, 2*time.Second, 10)
pm.relaySendMu.Lock()
pm.relaySends[chainId] = stopFunc
pm.relaySendMu.Unlock()
logger.Info("Sent unrelay message for site %d (chain %s)", siteID, chainId)
return nil
}
// CancelRelaySend stops the interval sender for the given chainId, if one exists.
// If chainId is empty, all active relay senders are stopped.
func (pm *PeerMonitor) CancelRelaySend(chainId string) {
pm.relaySendMu.Lock()
defer pm.relaySendMu.Unlock()
if chainId == "" {
for id, stop := range pm.relaySends {
if stop != nil {
stop()
}
delete(pm.relaySends, id)
}
logger.Info("Cancelled all relay senders")
return
}
if stop, ok := pm.relaySends[chainId]; ok {
stop()
delete(pm.relaySends, chainId)
logger.Info("Cancelled relay sender for chain %s", chainId)
} else {
logger.Warn("CancelRelaySend: no active sender for chain %s", chainId)
}
}
// Stop stops monitoring all peers
func (pm *PeerMonitor) Stop() {
// Stop holepunch monitor first (outside of mutex to avoid deadlock)
@@ -534,7 +581,7 @@ func (pm *PeerMonitor) runHolepunchMonitor() {
pm.holepunchCurrentInterval = pm.holepunchMinInterval
currentInterval := pm.holepunchCurrentInterval
pm.mutex.Unlock()
timer.Reset(currentInterval)
logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval)
case <-timer.C:
@@ -677,6 +724,16 @@ func (pm *PeerMonitor) Close() {
// Stop holepunch monitor first (outside of mutex to avoid deadlock)
pm.stopHolepunchMonitor()
// Stop all pending relay senders
pm.relaySendMu.Lock()
for chainId, stop := range pm.relaySends {
if stop != nil {
stop()
}
delete(pm.relaySends, chainId)
}
pm.relaySendMu.Unlock()
pm.mutex.Lock()
defer pm.mutex.Unlock()

View File

@@ -11,14 +11,14 @@ import (
)
// ConfigurePeer sets up or updates a peer within the WireGuard device
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool, persistentKeepalive int) error {
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool, persistentKeepalive int, publicDNS []string) error {
var endpoint string
if relay && siteConfig.RelayEndpoint != "" {
endpoint = formatEndpoint(siteConfig.RelayEndpoint)
} else {
endpoint = formatEndpoint(siteConfig.Endpoint)
}
siteHost, err := util.ResolveDomain(endpoint)
siteHost, err := util.ResolveDomainUpstream(endpoint, publicDNS)
if err != nil {
return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err)
}