mirror of
https://github.com/fosrl/olm.git
synced 2026-03-13 06:06:45 +00:00
Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c2b5ef96a4 | ||
|
|
e326da3d3e | ||
|
|
53def4e2f6 | ||
|
|
e85fd9d71e | ||
|
|
98a24960f5 | ||
|
|
e82387d515 | ||
|
|
b3cb3e1c92 | ||
|
|
22cd02ae15 | ||
|
|
3f258d3500 | ||
|
|
e2690bcc03 | ||
|
|
f2d0e6a14c | ||
|
|
ae88766d85 | ||
|
|
9ae49e36d5 | ||
|
|
5ca4825800 | ||
|
|
809dbe77de | ||
|
|
c67c2a60a1 | ||
|
|
051c0fdfd8 | ||
|
|
e7507e0837 | ||
|
|
21b66fbb34 | ||
|
|
9c0e37eddb |
59
api/api.go
59
api/api.go
@@ -78,6 +78,13 @@ type MetadataChangeRequest struct {
|
|||||||
Postures map[string]any `json:"postures"`
|
Postures map[string]any `json:"postures"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// JITConnectionRequest defines the structure for a dynamic Just-In-Time connection request.
|
||||||
|
// Either SiteID or ResourceID must be provided (but not necessarily both).
|
||||||
|
type JITConnectionRequest struct {
|
||||||
|
Site string `json:"site,omitempty"`
|
||||||
|
Resource string `json:"resource,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
// API represents the HTTP server and its state
|
// API represents the HTTP server and its state
|
||||||
type API struct {
|
type API struct {
|
||||||
addr string
|
addr string
|
||||||
@@ -92,6 +99,7 @@ type API struct {
|
|||||||
onExit func() error
|
onExit func() error
|
||||||
onRebind func() error
|
onRebind func() error
|
||||||
onPowerMode func(PowerModeRequest) error
|
onPowerMode func(PowerModeRequest) error
|
||||||
|
onJITConnect func(JITConnectionRequest) error
|
||||||
|
|
||||||
statusMu sync.RWMutex
|
statusMu sync.RWMutex
|
||||||
peerStatuses map[int]*PeerStatus
|
peerStatuses map[int]*PeerStatus
|
||||||
@@ -143,6 +151,7 @@ func (s *API) SetHandlers(
|
|||||||
onExit func() error,
|
onExit func() error,
|
||||||
onRebind func() error,
|
onRebind func() error,
|
||||||
onPowerMode func(PowerModeRequest) error,
|
onPowerMode func(PowerModeRequest) error,
|
||||||
|
onJITConnect func(JITConnectionRequest) error,
|
||||||
) {
|
) {
|
||||||
s.onConnect = onConnect
|
s.onConnect = onConnect
|
||||||
s.onSwitchOrg = onSwitchOrg
|
s.onSwitchOrg = onSwitchOrg
|
||||||
@@ -151,6 +160,7 @@ func (s *API) SetHandlers(
|
|||||||
s.onExit = onExit
|
s.onExit = onExit
|
||||||
s.onRebind = onRebind
|
s.onRebind = onRebind
|
||||||
s.onPowerMode = onPowerMode
|
s.onPowerMode = onPowerMode
|
||||||
|
s.onJITConnect = onJITConnect
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start starts the HTTP server
|
// Start starts the HTTP server
|
||||||
@@ -169,6 +179,7 @@ func (s *API) Start() error {
|
|||||||
mux.HandleFunc("/health", s.handleHealth)
|
mux.HandleFunc("/health", s.handleHealth)
|
||||||
mux.HandleFunc("/rebind", s.handleRebind)
|
mux.HandleFunc("/rebind", s.handleRebind)
|
||||||
mux.HandleFunc("/power-mode", s.handlePowerMode)
|
mux.HandleFunc("/power-mode", s.handlePowerMode)
|
||||||
|
mux.HandleFunc("/jit-connect", s.handleJITConnect)
|
||||||
|
|
||||||
s.server = &http.Server{
|
s.server = &http.Server{
|
||||||
Handler: mux,
|
Handler: mux,
|
||||||
@@ -633,6 +644,54 @@ func (s *API) handleRebind(w http.ResponseWriter, r *http.Request) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleJITConnect handles the /jit-connect endpoint.
|
||||||
|
// It initiates a dynamic Just-In-Time connection to a site identified by either
|
||||||
|
// a site or a resource. Exactly one of the two must be provided.
|
||||||
|
func (s *API) handleJITConnect(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req JITConnectionRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that exactly one of site or resource is provided
|
||||||
|
if req.Site == "" && req.Resource == "" {
|
||||||
|
http.Error(w, "Missing required field: either site or resource must be provided", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Site != "" && req.Resource != "" {
|
||||||
|
http.Error(w, "Ambiguous request: provide either site or resource, not both", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Site != "" {
|
||||||
|
logger.Info("Received JIT connection request via API: site=%s", req.Site)
|
||||||
|
} else {
|
||||||
|
logger.Info("Received JIT connection request via API: resource=%s", req.Resource)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.onJITConnect != nil {
|
||||||
|
if err := s.onJITConnect(req); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("JIT connection failed: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
http.Error(w, "JIT connect handler not configured", http.StatusNotImplemented)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusAccepted)
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"status": "JIT connection request accepted",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// handlePowerMode handles the /power-mode endpoint
|
// handlePowerMode handles the /power-mode endpoint
|
||||||
// This allows changing the power mode between "normal" and "low"
|
// This allows changing the power mode between "normal" and "low"
|
||||||
func (s *API) handlePowerMode(w http.ResponseWriter, r *http.Request) {
|
func (s *API) handlePowerMode(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|||||||
@@ -45,6 +45,11 @@ type DNSProxy struct {
|
|||||||
tunnelActivePorts map[uint16]bool
|
tunnelActivePorts map[uint16]bool
|
||||||
tunnelPortsLock sync.Mutex
|
tunnelPortsLock sync.Mutex
|
||||||
|
|
||||||
|
// jitHandler is called when a local record is resolved for a site that may not be
|
||||||
|
// connected yet, giving the caller a chance to initiate a JIT connection.
|
||||||
|
// It is invoked asynchronously so it never blocks DNS resolution.
|
||||||
|
jitHandler func(siteId int)
|
||||||
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
@@ -384,6 +389,16 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie
|
|||||||
response = p.checkLocalRecords(msg, question)
|
response = p.checkLocalRecords(msg, question)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If a local A/AAAA record was found, notify the JIT handler so that the owning
|
||||||
|
// site can be connected on-demand if it is not yet active.
|
||||||
|
if response != nil && p.jitHandler != nil &&
|
||||||
|
(question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA) {
|
||||||
|
if siteId, ok := p.recordStore.GetSiteIdForDomain(question.Name); ok && siteId != 0 {
|
||||||
|
handler := p.jitHandler
|
||||||
|
go handler(siteId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// If no local records, forward to upstream
|
// If no local records, forward to upstream
|
||||||
if response == nil {
|
if response == nil {
|
||||||
logger.Debug("No local record for %s, forwarding upstream to %v", question.Name, p.upstreamDNS)
|
logger.Debug("No local record for %s, forwarding upstream to %v", question.Name, p.upstreamDNS)
|
||||||
@@ -447,19 +462,20 @@ func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ips := p.recordStore.GetRecords(question.Name, recordType)
|
ips, exists := p.recordStore.GetRecords(question.Name, recordType)
|
||||||
if len(ips) == 0 {
|
if !exists {
|
||||||
|
// Domain not found in local records, forward to upstream
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug("Found %d local record(s) for %s", len(ips), question.Name)
|
logger.Debug("Found %d local record(s) for %s", len(ips), question.Name)
|
||||||
|
|
||||||
// Create response message
|
// Create response message (NODATA if no records, otherwise with answers)
|
||||||
response := new(dns.Msg)
|
response := new(dns.Msg)
|
||||||
response.SetReply(query)
|
response.SetReply(query)
|
||||||
response.Authoritative = true
|
response.Authoritative = true
|
||||||
|
|
||||||
// Add answer records
|
// Add answer records (loop is a no-op if ips is empty)
|
||||||
for _, ip := range ips {
|
for _, ip := range ips {
|
||||||
var rr dns.RR
|
var rr dns.RR
|
||||||
if question.Qtype == dns.TypeA {
|
if question.Qtype == dns.TypeA {
|
||||||
@@ -717,11 +733,20 @@ func (p *DNSProxy) runPacketSender() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetJITHandler registers a callback that is invoked whenever a local DNS record is
|
||||||
|
// resolved for an A or AAAA query. The siteId identifies which site owns the record.
|
||||||
|
// The handler is called in its own goroutine so it must be safe to call concurrently.
|
||||||
|
// Pass nil to disable JIT notifications.
|
||||||
|
func (p *DNSProxy) SetJITHandler(handler func(siteId int)) {
|
||||||
|
p.jitHandler = handler
|
||||||
|
}
|
||||||
|
|
||||||
// AddDNSRecord adds a DNS record to the local store
|
// AddDNSRecord adds a DNS record to the local store
|
||||||
// domain should be a domain name (e.g., "example.com" or "example.com.")
|
// domain should be a domain name (e.g., "example.com" or "example.com.")
|
||||||
// ip should be a valid IPv4 or IPv6 address
|
// ip should be a valid IPv4 or IPv6 address
|
||||||
func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP) error {
|
func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP, siteId int) error {
|
||||||
return p.recordStore.AddRecord(domain, ip)
|
logger.Debug("Adding dns record for domain %s with IP %s (siteId=%d)", domain, ip.String(), siteId)
|
||||||
|
return p.recordStore.AddRecord(domain, ip, siteId)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveDNSRecord removes a DNS record from the local store
|
// RemoveDNSRecord removes a DNS record from the local store
|
||||||
@@ -730,8 +755,9 @@ func (p *DNSProxy) RemoveDNSRecord(domain string, ip net.IP) {
|
|||||||
p.recordStore.RemoveRecord(domain, ip)
|
p.recordStore.RemoveRecord(domain, ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDNSRecords returns all IP addresses for a domain and record type
|
// GetDNSRecords returns all IP addresses for a domain and record type.
|
||||||
func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) []net.IP {
|
// The second return value indicates whether the domain exists.
|
||||||
|
func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) ([]net.IP, bool) {
|
||||||
return p.recordStore.GetRecords(domain, recordType)
|
return p.recordStore.GetRecords(domain, recordType)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
178
dns/dns_proxy_test.go
Normal file
178
dns/dns_proxy_test.go
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCheckLocalRecordsNODATAForAAAA(t *testing.T) {
|
||||||
|
proxy := &DNSProxy{
|
||||||
|
recordStore: NewDNSRecordStore(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add an A record for a domain (no AAAA record)
|
||||||
|
ip := net.ParseIP("10.0.0.1")
|
||||||
|
err := proxy.recordStore.AddRecord("myservice.internal", ip, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to add A record: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query AAAA for domain with only A record - should return NODATA
|
||||||
|
query := new(dns.Msg)
|
||||||
|
query.SetQuestion("myservice.internal.", dns.TypeAAAA)
|
||||||
|
response := proxy.checkLocalRecords(query, query.Question[0])
|
||||||
|
|
||||||
|
if response == nil {
|
||||||
|
t.Fatal("Expected NODATA response, got nil (would forward to upstream)")
|
||||||
|
}
|
||||||
|
if response.Rcode != dns.RcodeSuccess {
|
||||||
|
t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode)
|
||||||
|
}
|
||||||
|
if len(response.Answer) != 0 {
|
||||||
|
t.Errorf("Expected empty answer section for NODATA, got %d answers", len(response.Answer))
|
||||||
|
}
|
||||||
|
if !response.Authoritative {
|
||||||
|
t.Error("Expected response to be authoritative")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query A for same domain - should return the record
|
||||||
|
query = new(dns.Msg)
|
||||||
|
query.SetQuestion("myservice.internal.", dns.TypeA)
|
||||||
|
response = proxy.checkLocalRecords(query, query.Question[0])
|
||||||
|
|
||||||
|
if response == nil {
|
||||||
|
t.Fatal("Expected response with A record, got nil")
|
||||||
|
}
|
||||||
|
if len(response.Answer) != 1 {
|
||||||
|
t.Fatalf("Expected 1 answer, got %d", len(response.Answer))
|
||||||
|
}
|
||||||
|
aRecord, ok := response.Answer[0].(*dns.A)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Expected A record in answer")
|
||||||
|
}
|
||||||
|
if !aRecord.A.Equal(ip.To4()) {
|
||||||
|
t.Errorf("Expected IP %v, got %v", ip.To4(), aRecord.A)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckLocalRecordsNODATAForA(t *testing.T) {
|
||||||
|
proxy := &DNSProxy{
|
||||||
|
recordStore: NewDNSRecordStore(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add an AAAA record for a domain (no A record)
|
||||||
|
ip := net.ParseIP("2001:db8::1")
|
||||||
|
err := proxy.recordStore.AddRecord("ipv6only.internal", ip, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to add AAAA record: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query A for domain with only AAAA record - should return NODATA
|
||||||
|
query := new(dns.Msg)
|
||||||
|
query.SetQuestion("ipv6only.internal.", dns.TypeA)
|
||||||
|
response := proxy.checkLocalRecords(query, query.Question[0])
|
||||||
|
|
||||||
|
if response == nil {
|
||||||
|
t.Fatal("Expected NODATA response, got nil")
|
||||||
|
}
|
||||||
|
if response.Rcode != dns.RcodeSuccess {
|
||||||
|
t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode)
|
||||||
|
}
|
||||||
|
if len(response.Answer) != 0 {
|
||||||
|
t.Errorf("Expected empty answer section, got %d answers", len(response.Answer))
|
||||||
|
}
|
||||||
|
if !response.Authoritative {
|
||||||
|
t.Error("Expected response to be authoritative")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query AAAA for same domain - should return the record
|
||||||
|
query = new(dns.Msg)
|
||||||
|
query.SetQuestion("ipv6only.internal.", dns.TypeAAAA)
|
||||||
|
response = proxy.checkLocalRecords(query, query.Question[0])
|
||||||
|
|
||||||
|
if response == nil {
|
||||||
|
t.Fatal("Expected response with AAAA record, got nil")
|
||||||
|
}
|
||||||
|
if len(response.Answer) != 1 {
|
||||||
|
t.Fatalf("Expected 1 answer, got %d", len(response.Answer))
|
||||||
|
}
|
||||||
|
aaaaRecord, ok := response.Answer[0].(*dns.AAAA)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Expected AAAA record in answer")
|
||||||
|
}
|
||||||
|
if !aaaaRecord.AAAA.Equal(ip) {
|
||||||
|
t.Errorf("Expected IP %v, got %v", ip, aaaaRecord.AAAA)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckLocalRecordsNonExistentDomain(t *testing.T) {
|
||||||
|
proxy := &DNSProxy{
|
||||||
|
recordStore: NewDNSRecordStore(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add a record so the store isn't empty
|
||||||
|
err := proxy.recordStore.AddRecord("exists.internal", net.ParseIP("10.0.0.1"), 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to add record: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query A for non-existent domain - should return nil (forward to upstream)
|
||||||
|
query := new(dns.Msg)
|
||||||
|
query.SetQuestion("unknown.internal.", dns.TypeA)
|
||||||
|
response := proxy.checkLocalRecords(query, query.Question[0])
|
||||||
|
|
||||||
|
if response != nil {
|
||||||
|
t.Error("Expected nil for non-existent domain, got response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query AAAA for non-existent domain - should also return nil
|
||||||
|
query = new(dns.Msg)
|
||||||
|
query.SetQuestion("unknown.internal.", dns.TypeAAAA)
|
||||||
|
response = proxy.checkLocalRecords(query, query.Question[0])
|
||||||
|
|
||||||
|
if response != nil {
|
||||||
|
t.Error("Expected nil for non-existent domain, got response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckLocalRecordsNODATAWildcard(t *testing.T) {
|
||||||
|
proxy := &DNSProxy{
|
||||||
|
recordStore: NewDNSRecordStore(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add a wildcard A record (no AAAA)
|
||||||
|
ip := net.ParseIP("10.0.0.1")
|
||||||
|
err := proxy.recordStore.AddRecord("*.wildcard.internal", ip, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to add wildcard A record: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query AAAA for wildcard-matched domain - should return NODATA
|
||||||
|
query := new(dns.Msg)
|
||||||
|
query.SetQuestion("host.wildcard.internal.", dns.TypeAAAA)
|
||||||
|
response := proxy.checkLocalRecords(query, query.Question[0])
|
||||||
|
|
||||||
|
if response == nil {
|
||||||
|
t.Fatal("Expected NODATA response for wildcard match, got nil")
|
||||||
|
}
|
||||||
|
if response.Rcode != dns.RcodeSuccess {
|
||||||
|
t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode)
|
||||||
|
}
|
||||||
|
if len(response.Answer) != 0 {
|
||||||
|
t.Errorf("Expected empty answer section, got %d answers", len(response.Answer))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query A for wildcard-matched domain - should return the record
|
||||||
|
query = new(dns.Msg)
|
||||||
|
query.SetQuestion("host.wildcard.internal.", dns.TypeA)
|
||||||
|
response = proxy.checkLocalRecords(query, query.Question[0])
|
||||||
|
|
||||||
|
if response == nil {
|
||||||
|
t.Fatal("Expected response with A record, got nil")
|
||||||
|
}
|
||||||
|
if len(response.Answer) != 1 {
|
||||||
|
t.Fatalf("Expected 1 answer, got %d", len(response.Answer))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -18,24 +18,28 @@ const (
|
|||||||
RecordTypePTR RecordType = RecordType(dns.TypePTR)
|
RecordTypePTR RecordType = RecordType(dns.TypePTR)
|
||||||
)
|
)
|
||||||
|
|
||||||
// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries
|
// recordSet holds A and AAAA records for a single domain or wildcard pattern
|
||||||
|
type recordSet struct {
|
||||||
|
A []net.IP
|
||||||
|
AAAA []net.IP
|
||||||
|
SiteId int
|
||||||
|
}
|
||||||
|
|
||||||
|
// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries.
|
||||||
|
// Exact domains are stored in a map; wildcard patterns are in a separate map.
|
||||||
type DNSRecordStore struct {
|
type DNSRecordStore struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
aRecords map[string][]net.IP // domain -> list of IPv4 addresses
|
exact map[string]*recordSet // normalized FQDN -> A/AAAA records
|
||||||
aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses
|
wildcards map[string]*recordSet // wildcard pattern -> A/AAAA records
|
||||||
aWildcards map[string][]net.IP // wildcard pattern -> list of IPv4 addresses
|
ptrRecords map[string]string // IP address string -> domain name
|
||||||
aaaaWildcards map[string][]net.IP // wildcard pattern -> list of IPv6 addresses
|
|
||||||
ptrRecords map[string]string // IP address string -> domain name
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDNSRecordStore creates a new DNS record store
|
// NewDNSRecordStore creates a new DNS record store
|
||||||
func NewDNSRecordStore() *DNSRecordStore {
|
func NewDNSRecordStore() *DNSRecordStore {
|
||||||
return &DNSRecordStore{
|
return &DNSRecordStore{
|
||||||
aRecords: make(map[string][]net.IP),
|
exact: make(map[string]*recordSet),
|
||||||
aaaaRecords: make(map[string][]net.IP),
|
wildcards: make(map[string]*recordSet),
|
||||||
aWildcards: make(map[string][]net.IP),
|
ptrRecords: make(map[string]string),
|
||||||
aaaaWildcards: make(map[string][]net.IP),
|
|
||||||
ptrRecords: make(map[string]string),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -43,47 +47,57 @@ func NewDNSRecordStore() *DNSRecordStore {
|
|||||||
// domain should be in FQDN format (e.g., "example.com.")
|
// domain should be in FQDN format (e.g., "example.com.")
|
||||||
// domain can contain wildcards: * (0+ chars) and ? (exactly 1 char)
|
// domain can contain wildcards: * (0+ chars) and ? (exactly 1 char)
|
||||||
// ip should be a valid IPv4 or IPv6 address
|
// ip should be a valid IPv4 or IPv6 address
|
||||||
|
// siteId is the site that owns this alias/domain
|
||||||
// Automatically adds a corresponding PTR record for non-wildcard domains
|
// Automatically adds a corresponding PTR record for non-wildcard domains
|
||||||
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
|
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP, siteId int) error {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
// Ensure domain ends with a dot (FQDN format)
|
|
||||||
if len(domain) == 0 || domain[len(domain)-1] != '.' {
|
if len(domain) == 0 || domain[len(domain)-1] != '.' {
|
||||||
domain = domain + "."
|
domain = domain + "."
|
||||||
}
|
}
|
||||||
|
|
||||||
// Normalize domain to lowercase FQDN
|
|
||||||
domain = strings.ToLower(dns.Fqdn(domain))
|
domain = strings.ToLower(dns.Fqdn(domain))
|
||||||
|
|
||||||
// Check if domain contains wildcards
|
|
||||||
isWildcard := strings.ContainsAny(domain, "*?")
|
isWildcard := strings.ContainsAny(domain, "*?")
|
||||||
|
|
||||||
if ip.To4() != nil {
|
isV4 := ip.To4() != nil
|
||||||
// IPv4 address
|
if !isV4 && ip.To16() == nil {
|
||||||
if isWildcard {
|
|
||||||
s.aWildcards[domain] = append(s.aWildcards[domain], ip)
|
|
||||||
} else {
|
|
||||||
s.aRecords[domain] = append(s.aRecords[domain], ip)
|
|
||||||
// Automatically add PTR record for non-wildcard domains
|
|
||||||
s.ptrRecords[ip.String()] = domain
|
|
||||||
}
|
|
||||||
} else if ip.To16() != nil {
|
|
||||||
// IPv6 address
|
|
||||||
if isWildcard {
|
|
||||||
s.aaaaWildcards[domain] = append(s.aaaaWildcards[domain], ip)
|
|
||||||
} else {
|
|
||||||
s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip)
|
|
||||||
// Automatically add PTR record for non-wildcard domains
|
|
||||||
s.ptrRecords[ip.String()] = domain
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return &net.ParseError{Type: "IP address", Text: ip.String()}
|
return &net.ParseError{Type: "IP address", Text: ip.String()}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Choose the appropriate map based on whether this is a wildcard
|
||||||
|
m := s.exact
|
||||||
|
if isWildcard {
|
||||||
|
m = s.wildcards
|
||||||
|
}
|
||||||
|
|
||||||
|
if m[domain] == nil {
|
||||||
|
m[domain] = &recordSet{SiteId: siteId}
|
||||||
|
}
|
||||||
|
rs := m[domain]
|
||||||
|
if isV4 {
|
||||||
|
for _, existing := range rs.A {
|
||||||
|
if existing.Equal(ip) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rs.A = append(rs.A, ip)
|
||||||
|
} else {
|
||||||
|
for _, existing := range rs.AAAA {
|
||||||
|
if existing.Equal(ip) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rs.AAAA = append(rs.AAAA, ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add PTR record for non-wildcard domains
|
||||||
|
if !isWildcard {
|
||||||
|
s.ptrRecords[ip.String()] = domain
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// AddPTRRecord adds a PTR record mapping an IP address to a domain name
|
// AddPTRRecord adds a PTR record mapping an IP address to a domain name
|
||||||
// ip should be a valid IPv4 or IPv6 address
|
// ip should be a valid IPv4 or IPv6 address
|
||||||
// domain should be in FQDN format (e.g., "example.com.")
|
// domain should be in FQDN format (e.g., "example.com.")
|
||||||
@@ -112,89 +126,62 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
// Ensure domain ends with a dot (FQDN format)
|
|
||||||
if len(domain) == 0 || domain[len(domain)-1] != '.' {
|
if len(domain) == 0 || domain[len(domain)-1] != '.' {
|
||||||
domain = domain + "."
|
domain = domain + "."
|
||||||
}
|
}
|
||||||
|
|
||||||
// Normalize domain to lowercase FQDN
|
|
||||||
domain = strings.ToLower(dns.Fqdn(domain))
|
domain = strings.ToLower(dns.Fqdn(domain))
|
||||||
|
|
||||||
// Check if domain contains wildcards
|
|
||||||
isWildcard := strings.ContainsAny(domain, "*?")
|
isWildcard := strings.ContainsAny(domain, "*?")
|
||||||
|
|
||||||
if ip == nil {
|
// Choose the appropriate map
|
||||||
// Remove all records for this domain
|
m := s.exact
|
||||||
if isWildcard {
|
if isWildcard {
|
||||||
delete(s.aWildcards, domain)
|
m = s.wildcards
|
||||||
delete(s.aaaaWildcards, domain)
|
}
|
||||||
} else {
|
|
||||||
// For non-wildcard domains, remove PTR records for all IPs
|
rs := m[domain]
|
||||||
if ips, ok := s.aRecords[domain]; ok {
|
if rs == nil {
|
||||||
for _, ipAddr := range ips {
|
|
||||||
// Only remove PTR if it points to this domain
|
|
||||||
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
|
|
||||||
delete(s.ptrRecords, ipAddr.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if ips, ok := s.aaaaRecords[domain]; ok {
|
|
||||||
for _, ipAddr := range ips {
|
|
||||||
// Only remove PTR if it points to this domain
|
|
||||||
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
|
|
||||||
delete(s.ptrRecords, ipAddr.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
delete(s.aRecords, domain)
|
|
||||||
delete(s.aaaaRecords, domain)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ip == nil {
|
||||||
|
// Remove all records for this domain
|
||||||
|
if !isWildcard {
|
||||||
|
for _, ipAddr := range rs.A {
|
||||||
|
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
|
||||||
|
delete(s.ptrRecords, ipAddr.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, ipAddr := range rs.AAAA {
|
||||||
|
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
|
||||||
|
delete(s.ptrRecords, ipAddr.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete(m, domain)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove specific IP
|
||||||
if ip.To4() != nil {
|
if ip.To4() != nil {
|
||||||
// Remove specific IPv4 address
|
rs.A = removeIP(rs.A, ip)
|
||||||
if isWildcard {
|
if !isWildcard {
|
||||||
if ips, ok := s.aWildcards[domain]; ok {
|
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
|
||||||
s.aWildcards[domain] = removeIP(ips, ip)
|
delete(s.ptrRecords, ip.String())
|
||||||
if len(s.aWildcards[domain]) == 0 {
|
|
||||||
delete(s.aWildcards, domain)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if ips, ok := s.aRecords[domain]; ok {
|
|
||||||
s.aRecords[domain] = removeIP(ips, ip)
|
|
||||||
if len(s.aRecords[domain]) == 0 {
|
|
||||||
delete(s.aRecords, domain)
|
|
||||||
}
|
|
||||||
// Automatically remove PTR record if it points to this domain
|
|
||||||
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
|
|
||||||
delete(s.ptrRecords, ip.String())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if ip.To16() != nil {
|
} else {
|
||||||
// Remove specific IPv6 address
|
rs.AAAA = removeIP(rs.AAAA, ip)
|
||||||
if isWildcard {
|
if !isWildcard {
|
||||||
if ips, ok := s.aaaaWildcards[domain]; ok {
|
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
|
||||||
s.aaaaWildcards[domain] = removeIP(ips, ip)
|
delete(s.ptrRecords, ip.String())
|
||||||
if len(s.aaaaWildcards[domain]) == 0 {
|
|
||||||
delete(s.aaaaWildcards, domain)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if ips, ok := s.aaaaRecords[domain]; ok {
|
|
||||||
s.aaaaRecords[domain] = removeIP(ips, ip)
|
|
||||||
if len(s.aaaaRecords[domain]) == 0 {
|
|
||||||
delete(s.aaaaRecords, domain)
|
|
||||||
}
|
|
||||||
// Automatically remove PTR record if it points to this domain
|
|
||||||
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
|
|
||||||
delete(s.ptrRecords, ip.String())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clean up empty record sets
|
||||||
|
if len(rs.A) == 0 && len(rs.AAAA) == 0 {
|
||||||
|
delete(m, domain)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemovePTRRecord removes a PTR record for an IP address
|
// RemovePTRRecord removes a PTR record for an IP address
|
||||||
@@ -205,61 +192,80 @@ func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) {
|
|||||||
delete(s.ptrRecords, ip.String())
|
delete(s.ptrRecords, ip.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRecords returns all IP addresses for a domain and record type
|
// GetSiteIdForDomain returns the siteId associated with the given domain.
|
||||||
// First checks for exact matches, then checks wildcard patterns
|
// It checks exact matches first, then wildcard patterns.
|
||||||
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP {
|
// The second return value is false if the domain is not found in local records.
|
||||||
|
func (s *DNSRecordStore) GetSiteIdForDomain(domain string) (int, bool) {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
// Normalize domain to lowercase FQDN
|
|
||||||
domain = strings.ToLower(dns.Fqdn(domain))
|
domain = strings.ToLower(dns.Fqdn(domain))
|
||||||
|
|
||||||
var records []net.IP
|
// Check exact match first
|
||||||
switch recordType {
|
if rs, exists := s.exact[domain]; exists {
|
||||||
case RecordTypeA:
|
return rs.SiteId, true
|
||||||
// Check exact match first
|
}
|
||||||
if ips, ok := s.aRecords[domain]; ok {
|
|
||||||
// Return a copy to prevent external modifications
|
|
||||||
records = make([]net.IP, len(ips))
|
|
||||||
copy(records, ips)
|
|
||||||
return records
|
|
||||||
}
|
|
||||||
// Check wildcard patterns
|
|
||||||
for pattern, ips := range s.aWildcards {
|
|
||||||
if matchWildcard(pattern, domain) {
|
|
||||||
records = append(records, ips...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(records) > 0 {
|
|
||||||
// Return a copy
|
|
||||||
result := make([]net.IP, len(records))
|
|
||||||
copy(result, records)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
case RecordTypeAAAA:
|
// Check wildcard matches
|
||||||
// Check exact match first
|
for pattern, rs := range s.wildcards {
|
||||||
if ips, ok := s.aaaaRecords[domain]; ok {
|
if matchWildcard(pattern, domain) {
|
||||||
// Return a copy to prevent external modifications
|
return rs.SiteId, true
|
||||||
records = make([]net.IP, len(ips))
|
|
||||||
copy(records, ips)
|
|
||||||
return records
|
|
||||||
}
|
|
||||||
// Check wildcard patterns
|
|
||||||
for pattern, ips := range s.aaaaWildcards {
|
|
||||||
if matchWildcard(pattern, domain) {
|
|
||||||
records = append(records, ips...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(records) > 0 {
|
|
||||||
// Return a copy
|
|
||||||
result := make([]net.IP, len(records))
|
|
||||||
copy(result, records)
|
|
||||||
return result
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return records
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRecords returns all IP addresses for a domain and record type.
|
||||||
|
// The second return value indicates whether the domain exists at all
|
||||||
|
// (true = domain exists, use NODATA if no records; false = NXDOMAIN).
|
||||||
|
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) ([]net.IP, bool) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
domain = strings.ToLower(dns.Fqdn(domain))
|
||||||
|
|
||||||
|
// Check exact match first
|
||||||
|
if rs, exists := s.exact[domain]; exists {
|
||||||
|
var ips []net.IP
|
||||||
|
if recordType == RecordTypeA {
|
||||||
|
ips = rs.A
|
||||||
|
} else {
|
||||||
|
ips = rs.AAAA
|
||||||
|
}
|
||||||
|
if len(ips) > 0 {
|
||||||
|
out := make([]net.IP, len(ips))
|
||||||
|
copy(out, ips)
|
||||||
|
return out, true
|
||||||
|
}
|
||||||
|
// Domain exists but no records of this type
|
||||||
|
return nil, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check wildcard matches
|
||||||
|
var records []net.IP
|
||||||
|
matched := false
|
||||||
|
for pattern, rs := range s.wildcards {
|
||||||
|
if !matchWildcard(pattern, domain) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
matched = true
|
||||||
|
if recordType == RecordTypeA {
|
||||||
|
records = append(records, rs.A...)
|
||||||
|
} else {
|
||||||
|
records = append(records, rs.AAAA...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !matched {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if len(records) == 0 {
|
||||||
|
return nil, true
|
||||||
|
}
|
||||||
|
out := make([]net.IP, len(records))
|
||||||
|
copy(out, records)
|
||||||
|
return out, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPTRRecord returns the domain name for a PTR record query
|
// GetPTRRecord returns the domain name for a PTR record query
|
||||||
@@ -288,34 +294,30 @@ func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
|
|||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
// Normalize domain to lowercase FQDN
|
|
||||||
domain = strings.ToLower(dns.Fqdn(domain))
|
domain = strings.ToLower(dns.Fqdn(domain))
|
||||||
|
|
||||||
switch recordType {
|
// Check exact match
|
||||||
case RecordTypeA:
|
if rs, exists := s.exact[domain]; exists {
|
||||||
// Check exact match
|
if recordType == RecordTypeA && len(rs.A) > 0 {
|
||||||
if _, ok := s.aRecords[domain]; ok {
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
// Check wildcard patterns
|
if recordType == RecordTypeAAAA && len(rs.AAAA) > 0 {
|
||||||
for pattern := range s.aWildcards {
|
|
||||||
if matchWildcard(pattern, domain) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case RecordTypeAAAA:
|
|
||||||
// Check exact match
|
|
||||||
if _, ok := s.aaaaRecords[domain]; ok {
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
// Check wildcard patterns
|
|
||||||
for pattern := range s.aaaaWildcards {
|
|
||||||
if matchWildcard(pattern, domain) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check wildcard matches
|
||||||
|
for pattern, rs := range s.wildcards {
|
||||||
|
if !matchWildcard(pattern, domain) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if recordType == RecordTypeA && len(rs.A) > 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if recordType == RecordTypeAAAA && len(rs.AAAA) > 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -339,10 +341,8 @@ func (s *DNSRecordStore) Clear() {
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
s.aRecords = make(map[string][]net.IP)
|
s.exact = make(map[string]*recordSet)
|
||||||
s.aaaaRecords = make(map[string][]net.IP)
|
s.wildcards = make(map[string]*recordSet)
|
||||||
s.aWildcards = make(map[string][]net.IP)
|
|
||||||
s.aaaaWildcards = make(map[string][]net.IP)
|
|
||||||
s.ptrRecords = make(map[string]string)
|
s.ptrRecords = make(map[string]string)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -494,4 +494,4 @@ func IPToReverseDNS(ip net.IP) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -170,38 +170,47 @@ func TestDNSRecordStoreWildcard(t *testing.T) {
|
|||||||
|
|
||||||
// Add wildcard records
|
// Add wildcard records
|
||||||
wildcardIP := net.ParseIP("10.0.0.1")
|
wildcardIP := net.ParseIP("10.0.0.1")
|
||||||
err := store.AddRecord("*.autoco.internal", wildcardIP)
|
err := store.AddRecord("*.autoco.internal", wildcardIP, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add wildcard record: %v", err)
|
t.Fatalf("Failed to add wildcard record: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add exact record
|
// Add exact record
|
||||||
exactIP := net.ParseIP("10.0.0.2")
|
exactIP := net.ParseIP("10.0.0.2")
|
||||||
err = store.AddRecord("exact.autoco.internal", exactIP)
|
err = store.AddRecord("exact.autoco.internal", exactIP, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add exact record: %v", err)
|
t.Fatalf("Failed to add exact record: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test exact match takes precedence
|
// Test exact match takes precedence
|
||||||
ips := store.GetRecords("exact.autoco.internal.", RecordTypeA)
|
ips, exists := store.GetRecords("exact.autoco.internal.", RecordTypeA)
|
||||||
|
if !exists {
|
||||||
|
t.Error("Expected domain to exist")
|
||||||
|
}
|
||||||
if len(ips) != 1 {
|
if len(ips) != 1 {
|
||||||
t.Errorf("Expected 1 IP for exact match, got %d", len(ips))
|
t.Errorf("Expected 1 IP for exact match, got %d", len(ips))
|
||||||
}
|
}
|
||||||
if !ips[0].Equal(exactIP) {
|
if len(ips) > 0 && !ips[0].Equal(exactIP) {
|
||||||
t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0])
|
t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test wildcard match
|
// Test wildcard match
|
||||||
ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
|
ips, exists = store.GetRecords("host.autoco.internal.", RecordTypeA)
|
||||||
|
if !exists {
|
||||||
|
t.Error("Expected wildcard match to exist")
|
||||||
|
}
|
||||||
if len(ips) != 1 {
|
if len(ips) != 1 {
|
||||||
t.Errorf("Expected 1 IP for wildcard match, got %d", len(ips))
|
t.Errorf("Expected 1 IP for wildcard match, got %d", len(ips))
|
||||||
}
|
}
|
||||||
if !ips[0].Equal(wildcardIP) {
|
if len(ips) > 0 && !ips[0].Equal(wildcardIP) {
|
||||||
t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0])
|
t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test non-match (base domain)
|
// Test non-match (base domain)
|
||||||
ips = store.GetRecords("autoco.internal.", RecordTypeA)
|
ips, exists = store.GetRecords("autoco.internal.", RecordTypeA)
|
||||||
|
if exists {
|
||||||
|
t.Error("Expected base domain to not exist")
|
||||||
|
}
|
||||||
if len(ips) != 0 {
|
if len(ips) != 0 {
|
||||||
t.Errorf("Expected 0 IPs for base domain, got %d", len(ips))
|
t.Errorf("Expected 0 IPs for base domain, got %d", len(ips))
|
||||||
}
|
}
|
||||||
@@ -212,13 +221,16 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) {
|
|||||||
|
|
||||||
// Add complex wildcard pattern
|
// Add complex wildcard pattern
|
||||||
ip1 := net.ParseIP("10.0.0.1")
|
ip1 := net.ParseIP("10.0.0.1")
|
||||||
err := store.AddRecord("*.host-0?.autoco.internal", ip1)
|
err := store.AddRecord("*.host-0?.autoco.internal", ip1, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add wildcard record: %v", err)
|
t.Fatalf("Failed to add wildcard record: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test matching domain
|
// Test matching domain
|
||||||
ips := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA)
|
ips, exists := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA)
|
||||||
|
if !exists {
|
||||||
|
t.Error("Expected complex wildcard match to exist")
|
||||||
|
}
|
||||||
if len(ips) != 1 {
|
if len(ips) != 1 {
|
||||||
t.Errorf("Expected 1 IP for complex wildcard match, got %d", len(ips))
|
t.Errorf("Expected 1 IP for complex wildcard match, got %d", len(ips))
|
||||||
}
|
}
|
||||||
@@ -227,13 +239,19 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Test non-matching domain (missing prefix)
|
// Test non-matching domain (missing prefix)
|
||||||
ips = store.GetRecords("host-01.autoco.internal.", RecordTypeA)
|
ips, exists = store.GetRecords("host-01.autoco.internal.", RecordTypeA)
|
||||||
|
if exists {
|
||||||
|
t.Error("Expected domain without prefix to not exist")
|
||||||
|
}
|
||||||
if len(ips) != 0 {
|
if len(ips) != 0 {
|
||||||
t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips))
|
t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test non-matching domain (wrong ? position)
|
// Test non-matching domain (wrong ? position)
|
||||||
ips = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA)
|
ips, exists = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA)
|
||||||
|
if exists {
|
||||||
|
t.Error("Expected domain with wrong ? match to not exist")
|
||||||
|
}
|
||||||
if len(ips) != 0 {
|
if len(ips) != 0 {
|
||||||
t.Errorf("Expected 0 IPs for domain with wrong ? match, got %d", len(ips))
|
t.Errorf("Expected 0 IPs for domain with wrong ? match, got %d", len(ips))
|
||||||
}
|
}
|
||||||
@@ -244,13 +262,16 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
|
|||||||
|
|
||||||
// Add wildcard record
|
// Add wildcard record
|
||||||
ip := net.ParseIP("10.0.0.1")
|
ip := net.ParseIP("10.0.0.1")
|
||||||
err := store.AddRecord("*.autoco.internal", ip)
|
err := store.AddRecord("*.autoco.internal", ip, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add wildcard record: %v", err)
|
t.Fatalf("Failed to add wildcard record: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify it exists
|
// Verify it exists
|
||||||
ips := store.GetRecords("host.autoco.internal.", RecordTypeA)
|
ips, exists := store.GetRecords("host.autoco.internal.", RecordTypeA)
|
||||||
|
if !exists {
|
||||||
|
t.Error("Expected domain to exist before removal")
|
||||||
|
}
|
||||||
if len(ips) != 1 {
|
if len(ips) != 1 {
|
||||||
t.Errorf("Expected 1 IP before removal, got %d", len(ips))
|
t.Errorf("Expected 1 IP before removal, got %d", len(ips))
|
||||||
}
|
}
|
||||||
@@ -259,7 +280,10 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
|
|||||||
store.RemoveRecord("*.autoco.internal", nil)
|
store.RemoveRecord("*.autoco.internal", nil)
|
||||||
|
|
||||||
// Verify it's gone
|
// Verify it's gone
|
||||||
ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
|
ips, exists = store.GetRecords("host.autoco.internal.", RecordTypeA)
|
||||||
|
if exists {
|
||||||
|
t.Error("Expected domain to not exist after removal")
|
||||||
|
}
|
||||||
if len(ips) != 0 {
|
if len(ips) != 0 {
|
||||||
t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
|
t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
|
||||||
}
|
}
|
||||||
@@ -273,36 +297,36 @@ func TestDNSRecordStoreMultipleWildcards(t *testing.T) {
|
|||||||
ip2 := net.ParseIP("10.0.0.2")
|
ip2 := net.ParseIP("10.0.0.2")
|
||||||
ip3 := net.ParseIP("10.0.0.3")
|
ip3 := net.ParseIP("10.0.0.3")
|
||||||
|
|
||||||
err := store.AddRecord("*.prod.autoco.internal", ip1)
|
err := store.AddRecord("*.prod.autoco.internal", ip1, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add first wildcard: %v", err)
|
t.Fatalf("Failed to add first wildcard: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = store.AddRecord("*.dev.autoco.internal", ip2)
|
err = store.AddRecord("*.dev.autoco.internal", ip2, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add second wildcard: %v", err)
|
t.Fatalf("Failed to add second wildcard: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add a broader wildcard that matches both
|
// Add a broader wildcard that matches both
|
||||||
err = store.AddRecord("*.autoco.internal", ip3)
|
err = store.AddRecord("*.autoco.internal", ip3, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add third wildcard: %v", err)
|
t.Fatalf("Failed to add third wildcard: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test domain matching only the prod pattern and the broad pattern
|
// Test domain matching only the prod pattern and the broad pattern
|
||||||
ips := store.GetRecords("host.prod.autoco.internal.", RecordTypeA)
|
ips, _ := store.GetRecords("host.prod.autoco.internal.", RecordTypeA)
|
||||||
if len(ips) != 2 {
|
if len(ips) != 2 {
|
||||||
t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips))
|
t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test domain matching only the dev pattern and the broad pattern
|
// Test domain matching only the dev pattern and the broad pattern
|
||||||
ips = store.GetRecords("service.dev.autoco.internal.", RecordTypeA)
|
ips, _ = store.GetRecords("service.dev.autoco.internal.", RecordTypeA)
|
||||||
if len(ips) != 2 {
|
if len(ips) != 2 {
|
||||||
t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips))
|
t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test domain matching only the broad pattern
|
// Test domain matching only the broad pattern
|
||||||
ips = store.GetRecords("host.test.autoco.internal.", RecordTypeA)
|
ips, _ = store.GetRecords("host.test.autoco.internal.", RecordTypeA)
|
||||||
if len(ips) != 1 {
|
if len(ips) != 1 {
|
||||||
t.Errorf("Expected 1 IP (broad only), got %d", len(ips))
|
t.Errorf("Expected 1 IP (broad only), got %d", len(ips))
|
||||||
}
|
}
|
||||||
@@ -313,13 +337,13 @@ func TestDNSRecordStoreIPv6Wildcard(t *testing.T) {
|
|||||||
|
|
||||||
// Add IPv6 wildcard record
|
// Add IPv6 wildcard record
|
||||||
ip := net.ParseIP("2001:db8::1")
|
ip := net.ParseIP("2001:db8::1")
|
||||||
err := store.AddRecord("*.autoco.internal", ip)
|
err := store.AddRecord("*.autoco.internal", ip, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add IPv6 wildcard record: %v", err)
|
t.Fatalf("Failed to add IPv6 wildcard record: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test wildcard match for IPv6
|
// Test wildcard match for IPv6
|
||||||
ips := store.GetRecords("host.autoco.internal.", RecordTypeAAAA)
|
ips, _ := store.GetRecords("host.autoco.internal.", RecordTypeAAAA)
|
||||||
if len(ips) != 1 {
|
if len(ips) != 1 {
|
||||||
t.Errorf("Expected 1 IPv6 for wildcard match, got %d", len(ips))
|
t.Errorf("Expected 1 IPv6 for wildcard match, got %d", len(ips))
|
||||||
}
|
}
|
||||||
@@ -333,7 +357,7 @@ func TestHasRecordWildcard(t *testing.T) {
|
|||||||
|
|
||||||
// Add wildcard record
|
// Add wildcard record
|
||||||
ip := net.ParseIP("10.0.0.1")
|
ip := net.ParseIP("10.0.0.1")
|
||||||
err := store.AddRecord("*.autoco.internal", ip)
|
err := store.AddRecord("*.autoco.internal", ip, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add wildcard record: %v", err)
|
t.Fatalf("Failed to add wildcard record: %v", err)
|
||||||
}
|
}
|
||||||
@@ -354,7 +378,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
|
|||||||
|
|
||||||
// Add record with mixed case
|
// Add record with mixed case
|
||||||
ip := net.ParseIP("10.0.0.1")
|
ip := net.ParseIP("10.0.0.1")
|
||||||
err := store.AddRecord("MyHost.AutoCo.Internal", ip)
|
err := store.AddRecord("MyHost.AutoCo.Internal", ip, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add mixed case record: %v", err)
|
t.Fatalf("Failed to add mixed case record: %v", err)
|
||||||
}
|
}
|
||||||
@@ -368,7 +392,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, domain := range testCases {
|
for _, domain := range testCases {
|
||||||
ips := store.GetRecords(domain, RecordTypeA)
|
ips, _ := store.GetRecords(domain, RecordTypeA)
|
||||||
if len(ips) != 1 {
|
if len(ips) != 1 {
|
||||||
t.Errorf("Expected 1 IP for domain %q, got %d", domain, len(ips))
|
t.Errorf("Expected 1 IP for domain %q, got %d", domain, len(ips))
|
||||||
}
|
}
|
||||||
@@ -379,7 +403,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
|
|||||||
|
|
||||||
// Test wildcard with mixed case
|
// Test wildcard with mixed case
|
||||||
wildcardIP := net.ParseIP("10.0.0.2")
|
wildcardIP := net.ParseIP("10.0.0.2")
|
||||||
err = store.AddRecord("*.Example.Com", wildcardIP)
|
err = store.AddRecord("*.Example.Com", wildcardIP, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add mixed case wildcard: %v", err)
|
t.Fatalf("Failed to add mixed case wildcard: %v", err)
|
||||||
}
|
}
|
||||||
@@ -392,7 +416,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, domain := range wildcardTestCases {
|
for _, domain := range wildcardTestCases {
|
||||||
ips := store.GetRecords(domain, RecordTypeA)
|
ips, _ := store.GetRecords(domain, RecordTypeA)
|
||||||
if len(ips) != 1 {
|
if len(ips) != 1 {
|
||||||
t.Errorf("Expected 1 IP for wildcard domain %q, got %d", domain, len(ips))
|
t.Errorf("Expected 1 IP for wildcard domain %q, got %d", domain, len(ips))
|
||||||
}
|
}
|
||||||
@@ -403,7 +427,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
|
|||||||
|
|
||||||
// Test removal with different case
|
// Test removal with different case
|
||||||
store.RemoveRecord("MYHOST.AUTOCO.INTERNAL", nil)
|
store.RemoveRecord("MYHOST.AUTOCO.INTERNAL", nil)
|
||||||
ips := store.GetRecords("myhost.autoco.internal.", RecordTypeA)
|
ips, _ := store.GetRecords("myhost.autoco.internal.", RecordTypeA)
|
||||||
if len(ips) != 0 {
|
if len(ips) != 0 {
|
||||||
t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
|
t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
|
||||||
}
|
}
|
||||||
@@ -665,7 +689,7 @@ func TestClearPTRRecords(t *testing.T) {
|
|||||||
store.AddPTRRecord(ip2, "host2.example.com.")
|
store.AddPTRRecord(ip2, "host2.example.com.")
|
||||||
|
|
||||||
// Add some A records too
|
// Add some A records too
|
||||||
store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1"))
|
store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1"), 0)
|
||||||
|
|
||||||
// Verify PTR records exist
|
// Verify PTR records exist
|
||||||
if !store.HasPTRRecord("1.1.168.192.in-addr.arpa.") {
|
if !store.HasPTRRecord("1.1.168.192.in-addr.arpa.") {
|
||||||
@@ -695,7 +719,7 @@ func TestAutomaticPTRRecordOnAdd(t *testing.T) {
|
|||||||
// Add an A record - should automatically add PTR record
|
// Add an A record - should automatically add PTR record
|
||||||
domain := "host.example.com."
|
domain := "host.example.com."
|
||||||
ip := net.ParseIP("192.168.1.100")
|
ip := net.ParseIP("192.168.1.100")
|
||||||
err := store.AddRecord(domain, ip)
|
err := store.AddRecord(domain, ip, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add A record: %v", err)
|
t.Fatalf("Failed to add A record: %v", err)
|
||||||
}
|
}
|
||||||
@@ -713,7 +737,7 @@ func TestAutomaticPTRRecordOnAdd(t *testing.T) {
|
|||||||
// Add AAAA record - should also automatically add PTR record
|
// Add AAAA record - should also automatically add PTR record
|
||||||
domain6 := "ipv6host.example.com."
|
domain6 := "ipv6host.example.com."
|
||||||
ip6 := net.ParseIP("2001:db8::1")
|
ip6 := net.ParseIP("2001:db8::1")
|
||||||
err = store.AddRecord(domain6, ip6)
|
err = store.AddRecord(domain6, ip6, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add AAAA record: %v", err)
|
t.Fatalf("Failed to add AAAA record: %v", err)
|
||||||
}
|
}
|
||||||
@@ -735,7 +759,7 @@ func TestAutomaticPTRRecordOnRemove(t *testing.T) {
|
|||||||
// Add an A record (with automatic PTR)
|
// Add an A record (with automatic PTR)
|
||||||
domain := "host.example.com."
|
domain := "host.example.com."
|
||||||
ip := net.ParseIP("192.168.1.100")
|
ip := net.ParseIP("192.168.1.100")
|
||||||
store.AddRecord(domain, ip)
|
store.AddRecord(domain, ip, 0)
|
||||||
|
|
||||||
// Verify PTR exists
|
// Verify PTR exists
|
||||||
reverseDomain := "100.1.168.192.in-addr.arpa."
|
reverseDomain := "100.1.168.192.in-addr.arpa."
|
||||||
@@ -752,7 +776,7 @@ func TestAutomaticPTRRecordOnRemove(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Verify A record is also gone
|
// Verify A record is also gone
|
||||||
ips := store.GetRecords(domain, RecordTypeA)
|
ips, _ := store.GetRecords(domain, RecordTypeA)
|
||||||
if len(ips) != 0 {
|
if len(ips) != 0 {
|
||||||
t.Errorf("Expected A record to be removed, got %d records", len(ips))
|
t.Errorf("Expected A record to be removed, got %d records", len(ips))
|
||||||
}
|
}
|
||||||
@@ -765,8 +789,8 @@ func TestAutomaticPTRRecordOnRemoveAll(t *testing.T) {
|
|||||||
domain := "host.example.com."
|
domain := "host.example.com."
|
||||||
ip1 := net.ParseIP("192.168.1.100")
|
ip1 := net.ParseIP("192.168.1.100")
|
||||||
ip2 := net.ParseIP("192.168.1.101")
|
ip2 := net.ParseIP("192.168.1.101")
|
||||||
store.AddRecord(domain, ip1)
|
store.AddRecord(domain, ip1, 0)
|
||||||
store.AddRecord(domain, ip2)
|
store.AddRecord(domain, ip2, 0)
|
||||||
|
|
||||||
// Verify both PTR records exist
|
// Verify both PTR records exist
|
||||||
reverseDomain1 := "100.1.168.192.in-addr.arpa."
|
reverseDomain1 := "100.1.168.192.in-addr.arpa."
|
||||||
@@ -796,7 +820,7 @@ func TestNoPTRForWildcardRecords(t *testing.T) {
|
|||||||
// Add wildcard record - should NOT create PTR record
|
// Add wildcard record - should NOT create PTR record
|
||||||
domain := "*.example.com."
|
domain := "*.example.com."
|
||||||
ip := net.ParseIP("192.168.1.100")
|
ip := net.ParseIP("192.168.1.100")
|
||||||
err := store.AddRecord(domain, ip)
|
err := store.AddRecord(domain, ip, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add wildcard record: %v", err)
|
t.Fatalf("Failed to add wildcard record: %v", err)
|
||||||
}
|
}
|
||||||
@@ -820,7 +844,7 @@ func TestPTRRecordOverwrite(t *testing.T) {
|
|||||||
// Add first domain with IP
|
// Add first domain with IP
|
||||||
domain1 := "host1.example.com."
|
domain1 := "host1.example.com."
|
||||||
ip := net.ParseIP("192.168.1.100")
|
ip := net.ParseIP("192.168.1.100")
|
||||||
store.AddRecord(domain1, ip)
|
store.AddRecord(domain1, ip, 0)
|
||||||
|
|
||||||
// Verify PTR points to first domain
|
// Verify PTR points to first domain
|
||||||
reverseDomain := "100.1.168.192.in-addr.arpa."
|
reverseDomain := "100.1.168.192.in-addr.arpa."
|
||||||
@@ -834,7 +858,7 @@ func TestPTRRecordOverwrite(t *testing.T) {
|
|||||||
|
|
||||||
// Add second domain with same IP - should overwrite PTR
|
// Add second domain with same IP - should overwrite PTR
|
||||||
domain2 := "host2.example.com."
|
domain2 := "host2.example.com."
|
||||||
store.AddRecord(domain2, ip)
|
store.AddRecord(domain2, ip, 0)
|
||||||
|
|
||||||
// Verify PTR now points to second domain (last one added)
|
// Verify PTR now points to second domain (last one added)
|
||||||
result, ok = store.GetPTRRecord(reverseDomain)
|
result, ok = store.GetPTRRecord(reverseDomain)
|
||||||
|
|||||||
14
go.mod
14
go.mod
@@ -1,14 +1,14 @@
|
|||||||
module github.com/fosrl/olm
|
module github.com/fosrl/olm
|
||||||
|
|
||||||
go 1.25
|
go 1.25.0
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/Microsoft/go-winio v0.6.2
|
github.com/Microsoft/go-winio v0.6.2
|
||||||
github.com/fosrl/newt v1.9.0
|
github.com/fosrl/newt v1.10.3
|
||||||
github.com/godbus/dbus/v5 v5.2.2
|
github.com/godbus/dbus/v5 v5.2.2
|
||||||
github.com/gorilla/websocket v1.5.3
|
github.com/gorilla/websocket v1.5.3
|
||||||
github.com/miekg/dns v1.1.70
|
github.com/miekg/dns v1.1.70
|
||||||
golang.org/x/sys v0.40.0
|
golang.org/x/sys v0.41.0
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
|
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
||||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c
|
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c
|
||||||
@@ -20,13 +20,13 @@ require (
|
|||||||
github.com/google/go-cmp v0.7.0 // indirect
|
github.com/google/go-cmp v0.7.0 // indirect
|
||||||
github.com/vishvananda/netlink v1.3.1 // indirect
|
github.com/vishvananda/netlink v1.3.1 // indirect
|
||||||
github.com/vishvananda/netns v0.0.5 // indirect
|
github.com/vishvananda/netns v0.0.5 // indirect
|
||||||
golang.org/x/crypto v0.46.0 // indirect
|
golang.org/x/crypto v0.48.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect
|
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect
|
||||||
golang.org/x/mod v0.31.0 // indirect
|
golang.org/x/mod v0.32.0 // indirect
|
||||||
golang.org/x/net v0.48.0 // indirect
|
golang.org/x/net v0.51.0 // indirect
|
||||||
golang.org/x/sync v0.19.0 // indirect
|
golang.org/x/sync v0.19.0 // indirect
|
||||||
golang.org/x/time v0.12.0 // indirect
|
golang.org/x/time v0.12.0 // indirect
|
||||||
golang.org/x/tools v0.40.0 // indirect
|
golang.org/x/tools v0.41.0 // indirect
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
|
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
24
go.sum
24
go.sum
@@ -1,7 +1,7 @@
|
|||||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||||
github.com/fosrl/newt v1.9.0 h1:66eJMo6fA+YcBTbddxTfNJXNQo1WWKzmn6zPRP5kSDE=
|
github.com/fosrl/newt v1.10.3 h1:JO9gFK9LP/w2EeDIn4wU+jKggAFPo06hX5hxFSETqcw=
|
||||||
github.com/fosrl/newt v1.9.0/go.mod h1:d1+yYMnKqg4oLqAM9zdbjthjj2FQEVouiACjqU468ck=
|
github.com/fosrl/newt v1.10.3/go.mod h1:iYuuCAG7iabheiogMOX87r61uQN31S39nKxMKRuLS+s=
|
||||||
github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ=
|
github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ=
|
||||||
github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
|
github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
|
||||||
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
||||||
@@ -16,24 +16,24 @@ github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW
|
|||||||
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
|
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
|
||||||
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
|
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
|
||||||
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||||
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
||||||
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
|
||||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
|
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
|
||||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
|
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
|
||||||
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
|
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
|
||||||
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
|
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
|
||||||
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
|
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
|
||||||
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
|
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
|
||||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||||
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||||
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
|
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
|
||||||
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
|
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
|
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
|
||||||
|
|||||||
14
olm.iss
14
olm.iss
@@ -32,7 +32,7 @@ DefaultGroupName={#MyAppName}
|
|||||||
DisableProgramGroupPage=yes
|
DisableProgramGroupPage=yes
|
||||||
; Uncomment the following line to run in non administrative install mode (install for current user only).
|
; Uncomment the following line to run in non administrative install mode (install for current user only).
|
||||||
;PrivilegesRequired=lowest
|
;PrivilegesRequired=lowest
|
||||||
OutputBaseFilename=mysetup
|
OutputBaseFilename=olm_windows_installer
|
||||||
SolidCompression=yes
|
SolidCompression=yes
|
||||||
WizardStyle=modern
|
WizardStyle=modern
|
||||||
; Add this to ensure PATH changes are applied and the system is prompted for a restart if needed
|
; Add this to ensure PATH changes are applied and the system is prompted for a restart if needed
|
||||||
@@ -78,7 +78,7 @@ begin
|
|||||||
Result := True;
|
Result := True;
|
||||||
exit;
|
exit;
|
||||||
end;
|
end;
|
||||||
|
|
||||||
// Perform a case-insensitive check to see if the path is already present.
|
// Perform a case-insensitive check to see if the path is already present.
|
||||||
// We add semicolons to prevent partial matches (e.g., matching C:\App in C:\App2).
|
// We add semicolons to prevent partial matches (e.g., matching C:\App in C:\App2).
|
||||||
if Pos(';' + UpperCase(Path) + ';', ';' + UpperCase(OrigPath) + ';') > 0 then
|
if Pos(';' + UpperCase(Path) + ';', ';' + UpperCase(OrigPath) + ';') > 0 then
|
||||||
@@ -109,7 +109,7 @@ begin
|
|||||||
PathList.Delimiter := ';';
|
PathList.Delimiter := ';';
|
||||||
PathList.StrictDelimiter := True;
|
PathList.StrictDelimiter := True;
|
||||||
PathList.DelimitedText := OrigPath;
|
PathList.DelimitedText := OrigPath;
|
||||||
|
|
||||||
// Find and remove the matching entry (case-insensitive)
|
// Find and remove the matching entry (case-insensitive)
|
||||||
for I := PathList.Count - 1 downto 0 do
|
for I := PathList.Count - 1 downto 0 do
|
||||||
begin
|
begin
|
||||||
@@ -119,10 +119,10 @@ begin
|
|||||||
PathList.Delete(I);
|
PathList.Delete(I);
|
||||||
end;
|
end;
|
||||||
end;
|
end;
|
||||||
|
|
||||||
// Reconstruct the PATH
|
// Reconstruct the PATH
|
||||||
NewPath := PathList.DelimitedText;
|
NewPath := PathList.DelimitedText;
|
||||||
|
|
||||||
// Write the new PATH back to the registry
|
// Write the new PATH back to the registry
|
||||||
if RegWriteExpandStringValue(HKEY_LOCAL_MACHINE,
|
if RegWriteExpandStringValue(HKEY_LOCAL_MACHINE,
|
||||||
'SYSTEM\CurrentControlSet\Control\Session Manager\Environment',
|
'SYSTEM\CurrentControlSet\Control\Session Manager\Environment',
|
||||||
@@ -145,8 +145,8 @@ begin
|
|||||||
// Get the application installation path
|
// Get the application installation path
|
||||||
AppPath := ExpandConstant('{app}');
|
AppPath := ExpandConstant('{app}');
|
||||||
Log('Removing PATH entry for: ' + AppPath);
|
Log('Removing PATH entry for: ' + AppPath);
|
||||||
|
|
||||||
// Remove only our path entry from the system PATH
|
// Remove only our path entry from the system PATH
|
||||||
RemovePathEntry(AppPath);
|
RemovePathEntry(AppPath);
|
||||||
end;
|
end;
|
||||||
end;
|
end;
|
||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/newt/network"
|
"github.com/fosrl/newt/network"
|
||||||
@@ -168,20 +169,25 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
|
|||||||
SharedBind: o.sharedBind,
|
SharedBind: o.sharedBind,
|
||||||
WSClient: o.websocket,
|
WSClient: o.websocket,
|
||||||
APIServer: o.apiServer,
|
APIServer: o.apiServer,
|
||||||
|
PublicDNS: o.tunnelConfig.PublicDNS,
|
||||||
})
|
})
|
||||||
|
|
||||||
for i := range wgData.Sites {
|
for i := range wgData.Sites {
|
||||||
site := wgData.Sites[i]
|
site := wgData.Sites[i]
|
||||||
var siteEndpoint string
|
|
||||||
// here we are going to take the relay endpoint if it exists which means we requested a relay for this peer
|
if site.PublicKey != "" {
|
||||||
if site.RelayEndpoint != "" {
|
var siteEndpoint string
|
||||||
siteEndpoint = site.RelayEndpoint
|
// here we are going to take the relay endpoint if it exists which means we requested a relay for this peer
|
||||||
} else {
|
if site.RelayEndpoint != "" {
|
||||||
siteEndpoint = site.Endpoint
|
siteEndpoint = site.RelayEndpoint
|
||||||
|
} else {
|
||||||
|
siteEndpoint = site.Endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false)
|
// we still call this to add the aliases for jit lookup but we just do that then pass inside. need to skip the above so we dont add to the api
|
||||||
|
|
||||||
if err := o.peerManager.AddPeer(site); err != nil {
|
if err := o.peerManager.AddPeer(site); err != nil {
|
||||||
logger.Error("Failed to add peer: %v", err)
|
logger.Error("Failed to add peer: %v", err)
|
||||||
return
|
return
|
||||||
@@ -196,6 +202,36 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
|
|||||||
logger.Error("Failed to start DNS proxy: %v", err)
|
logger.Error("Failed to start DNS proxy: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Register JIT handler: when the DNS proxy resolves a local record, check whether
|
||||||
|
// the owning site is already connected and, if not, initiate a JIT connection.
|
||||||
|
o.dnsProxy.SetJITHandler(func(siteId int) {
|
||||||
|
if o.peerManager == nil || o.websocket == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Site already has an active peer connection - nothing to do.
|
||||||
|
if _, exists := o.peerManager.GetPeer(siteId); exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
o.peerSendMu.Lock()
|
||||||
|
defer o.peerSendMu.Unlock()
|
||||||
|
|
||||||
|
// A JIT request for this site is already in-flight - avoid duplicate sends.
|
||||||
|
if _, pending := o.jitPendingSites[siteId]; pending {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
chainId := generateChainId()
|
||||||
|
logger.Info("DNS-triggered JIT connect for site %d (chainId=%s)", siteId, chainId)
|
||||||
|
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{
|
||||||
|
"siteId": siteId,
|
||||||
|
"chainId": chainId,
|
||||||
|
}, 2*time.Second, 10)
|
||||||
|
o.stopPeerInits[chainId] = stopFunc
|
||||||
|
o.jitPendingSites[siteId] = chainId
|
||||||
|
})
|
||||||
|
|
||||||
if o.tunnelConfig.OverrideDNS {
|
if o.tunnelConfig.OverrideDNS {
|
||||||
// Set up DNS override to use our DNS proxy
|
// Set up DNS override to use our DNS proxy
|
||||||
if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil {
|
if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil {
|
||||||
@@ -273,12 +309,12 @@ func (o *Olm) handleTerminate(msg websocket.WSMessage) {
|
|||||||
logger.Error("Error unmarshaling terminate error data: %v", err)
|
logger.Error("Error unmarshaling terminate error data: %v", err)
|
||||||
} else {
|
} else {
|
||||||
logger.Info("Terminate reason (code: %s): %s", errorData.Code, errorData.Message)
|
logger.Info("Terminate reason (code: %s): %s", errorData.Code, errorData.Message)
|
||||||
|
|
||||||
if errorData.Code == "TERMINATED_INACTIVITY" {
|
if errorData.Code == "TERMINATED_INACTIVITY" {
|
||||||
logger.Info("Ignoring...")
|
logger.Info("Ignoring...")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the olm error in the API server so it can be exposed via status
|
// Set the olm error in the API server so it can be exposed via status
|
||||||
o.apiServer.SetOlmError(errorData.Code, errorData.Message)
|
o.apiServer.SetOlmError(errorData.Code, errorData.Message)
|
||||||
}
|
}
|
||||||
|
|||||||
16
olm/data.go
16
olm/data.go
@@ -2,6 +2,7 @@ package olm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/holepunch"
|
"github.com/fosrl/newt/holepunch"
|
||||||
@@ -220,6 +221,7 @@ func (o *Olm) handleSync(msg websocket.WSMessage) {
|
|||||||
logger.Info("Sync: Adding new peer for site %d", siteId)
|
logger.Info("Sync: Adding new peer for site %d", siteId)
|
||||||
|
|
||||||
o.holePunchManager.TriggerHolePunch()
|
o.holePunchManager.TriggerHolePunch()
|
||||||
|
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
|
||||||
|
|
||||||
// // TODO: do we need to send the message to the cloud to add the peer that way?
|
// // TODO: do we need to send the message to the cloud to add the peer that way?
|
||||||
// if err := o.peerManager.AddPeer(expectedSite); err != nil {
|
// if err := o.peerManager.AddPeer(expectedSite); err != nil {
|
||||||
@@ -230,9 +232,17 @@ func (o *Olm) handleSync(msg websocket.WSMessage) {
|
|||||||
|
|
||||||
// add the peer via the server
|
// add the peer via the server
|
||||||
// this is important because newt needs to get triggered as well to add the peer once the hp is complete
|
// this is important because newt needs to get triggered as well to add the peer once the hp is complete
|
||||||
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
chainId := fmt.Sprintf("sync-%d", expectedSite.SiteId)
|
||||||
"siteId": expectedSite.SiteId,
|
o.peerSendMu.Lock()
|
||||||
}, 1*time.Second, 10)
|
if stop, ok := o.stopPeerSends[chainId]; ok {
|
||||||
|
stop()
|
||||||
|
}
|
||||||
|
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
||||||
|
"siteId": expectedSite.SiteId,
|
||||||
|
"chainId": chainId,
|
||||||
|
}, 2*time.Second, 10)
|
||||||
|
o.stopPeerSends[chainId] = stopFunc
|
||||||
|
o.peerSendMu.Unlock()
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
// Existing peer - check if update is needed
|
// Existing peer - check if update is needed
|
||||||
|
|||||||
81
olm/olm.go
81
olm/olm.go
@@ -2,6 +2,8 @@ package olm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -31,7 +33,7 @@ type Olm struct {
|
|||||||
privateKey wgtypes.Key
|
privateKey wgtypes.Key
|
||||||
logFile *os.File
|
logFile *os.File
|
||||||
|
|
||||||
registered bool
|
registered bool
|
||||||
tunnelRunning bool
|
tunnelRunning bool
|
||||||
|
|
||||||
uapiListener net.Listener
|
uapiListener net.Listener
|
||||||
@@ -65,7 +67,10 @@ type Olm struct {
|
|||||||
stopRegister func()
|
stopRegister func()
|
||||||
updateRegister func(newData any)
|
updateRegister func(newData any)
|
||||||
|
|
||||||
stopPeerSend func()
|
stopPeerSends map[string]func()
|
||||||
|
stopPeerInits map[string]func()
|
||||||
|
jitPendingSites map[int]string // siteId -> chainId for in-flight JIT requests
|
||||||
|
peerSendMu sync.Mutex
|
||||||
|
|
||||||
// WaitGroup to track tunnel lifecycle
|
// WaitGroup to track tunnel lifecycle
|
||||||
tunnelWg sync.WaitGroup
|
tunnelWg sync.WaitGroup
|
||||||
@@ -111,11 +116,18 @@ func (o *Olm) initTunnelInfo(clientID string) error {
|
|||||||
logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount())
|
logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount())
|
||||||
|
|
||||||
// Create the holepunch manager
|
// Create the holepunch manager
|
||||||
o.holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm", privateKey.PublicKey().String())
|
o.holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm", privateKey.PublicKey().String(), o.tunnelConfig.PublicDNS)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generateChainId generates a random chain ID for tracking peer sender lifecycles.
|
||||||
|
func generateChainId() string {
|
||||||
|
b := make([]byte, 8)
|
||||||
|
_, _ = rand.Read(b)
|
||||||
|
return hex.EncodeToString(b)
|
||||||
|
}
|
||||||
|
|
||||||
func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
|
func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
|
||||||
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
|
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
|
||||||
|
|
||||||
@@ -166,10 +178,13 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
|
|||||||
apiServer.SetAgent(config.Agent)
|
apiServer.SetAgent(config.Agent)
|
||||||
|
|
||||||
newOlm := &Olm{
|
newOlm := &Olm{
|
||||||
logFile: logFile,
|
logFile: logFile,
|
||||||
olmCtx: ctx,
|
olmCtx: ctx,
|
||||||
apiServer: apiServer,
|
apiServer: apiServer,
|
||||||
olmConfig: config,
|
olmConfig: config,
|
||||||
|
stopPeerSends: make(map[string]func()),
|
||||||
|
stopPeerInits: make(map[string]func()),
|
||||||
|
jitPendingSites: make(map[int]string),
|
||||||
}
|
}
|
||||||
|
|
||||||
newOlm.registerAPICallbacks()
|
newOlm.registerAPICallbacks()
|
||||||
@@ -222,7 +237,7 @@ func (o *Olm) registerAPICallbacks() {
|
|||||||
tunnelConfig.MTU = 1420
|
tunnelConfig.MTU = 1420
|
||||||
}
|
}
|
||||||
if req.DNS == "" {
|
if req.DNS == "" {
|
||||||
tunnelConfig.DNS = "9.9.9.9"
|
tunnelConfig.DNS = "8.8.8.8"
|
||||||
}
|
}
|
||||||
// DNSProxyIP has no default - it must be provided if DNS proxy is desired
|
// DNSProxyIP has no default - it must be provided if DNS proxy is desired
|
||||||
// UpstreamDNS defaults to 8.8.8.8 if not provided
|
// UpstreamDNS defaults to 8.8.8.8 if not provided
|
||||||
@@ -284,6 +299,21 @@ func (o *Olm) registerAPICallbacks() {
|
|||||||
logger.Info("Processing power mode change request via API: mode=%s", req.Mode)
|
logger.Info("Processing power mode change request via API: mode=%s", req.Mode)
|
||||||
return o.SetPowerMode(req.Mode)
|
return o.SetPowerMode(req.Mode)
|
||||||
},
|
},
|
||||||
|
func(req api.JITConnectionRequest) error {
|
||||||
|
logger.Info("Processing JIT connect request via API: site=%s resource=%s", req.Site, req.Resource)
|
||||||
|
|
||||||
|
chainId := generateChainId()
|
||||||
|
o.peerSendMu.Lock()
|
||||||
|
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{
|
||||||
|
"siteId": req.Site,
|
||||||
|
"resourceId": req.Resource,
|
||||||
|
"chainId": chainId,
|
||||||
|
}, 2*time.Second, 10)
|
||||||
|
o.stopPeerInits[chainId] = stopFunc
|
||||||
|
o.peerSendMu.Unlock()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -292,16 +322,23 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
|||||||
logger.Info("Tunnel already running")
|
logger.Info("Tunnel already running")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// debug print out the whole config
|
// debug print out the whole config
|
||||||
logger.Debug("Starting tunnel with config: %+v", config)
|
logger.Debug("Starting tunnel with config: %+v", config)
|
||||||
|
|
||||||
o.tunnelRunning = true // Also set it here in case it is called externally
|
o.tunnelRunning = true // Also set it here in case it is called externally
|
||||||
o.tunnelConfig = config
|
o.tunnelConfig = config
|
||||||
|
|
||||||
|
// TODO: we are hardcoding this for now but we should really pull it from the current config of the system
|
||||||
|
if o.tunnelConfig.DNS != "" {
|
||||||
|
o.tunnelConfig.PublicDNS = []string{o.tunnelConfig.DNS + ":53"}
|
||||||
|
} else {
|
||||||
|
o.tunnelConfig.PublicDNS = []string{"8.8.8.8:53"}
|
||||||
|
}
|
||||||
|
|
||||||
// Reset terminated status when tunnel starts
|
// Reset terminated status when tunnel starts
|
||||||
o.apiServer.SetTerminated(false)
|
o.apiServer.SetTerminated(false)
|
||||||
|
|
||||||
fingerprint := config.InitialFingerprint
|
fingerprint := config.InitialFingerprint
|
||||||
if fingerprint == nil {
|
if fingerprint == nil {
|
||||||
fingerprint = make(map[string]any)
|
fingerprint = make(map[string]any)
|
||||||
@@ -313,7 +350,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
o.SetFingerprint(fingerprint)
|
o.SetFingerprint(fingerprint)
|
||||||
o.SetPostures(postures)
|
o.SetPostures(postures)
|
||||||
|
|
||||||
// Create a cancellable context for this tunnel process
|
// Create a cancellable context for this tunnel process
|
||||||
tunnelCtx, cancel := context.WithCancel(o.olmCtx)
|
tunnelCtx, cancel := context.WithCancel(o.olmCtx)
|
||||||
@@ -378,6 +415,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
|||||||
|
|
||||||
// Handler for peer handshake - adds exit node to holepunch rotation and notifies server
|
// Handler for peer handshake - adds exit node to holepunch rotation and notifies server
|
||||||
o.websocket.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite)
|
o.websocket.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite)
|
||||||
|
o.websocket.RegisterHandler("olm/wg/peer/chain/cancel", o.handleCancelChain)
|
||||||
o.websocket.RegisterHandler("olm/sync", o.handleSync)
|
o.websocket.RegisterHandler("olm/sync", o.handleSync)
|
||||||
|
|
||||||
o.websocket.OnConnect(func() error {
|
o.websocket.OnConnect(func() error {
|
||||||
@@ -387,7 +425,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
|||||||
|
|
||||||
if o.registered {
|
if o.registered {
|
||||||
o.websocket.StartPingMonitor()
|
o.websocket.StartPingMonitor()
|
||||||
|
|
||||||
logger.Debug("Already registered, skipping registration")
|
logger.Debug("Already registered, skipping registration")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -420,7 +458,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
|||||||
"userToken": userToken,
|
"userToken": userToken,
|
||||||
"fingerprint": o.fingerprint,
|
"fingerprint": o.fingerprint,
|
||||||
"postures": o.postures,
|
"postures": o.postures,
|
||||||
}, 1*time.Second, 10)
|
}, 2*time.Second, 10)
|
||||||
|
|
||||||
// Invoke onRegistered callback if configured
|
// Invoke onRegistered callback if configured
|
||||||
if o.olmConfig.OnRegistered != nil {
|
if o.olmConfig.OnRegistered != nil {
|
||||||
@@ -517,6 +555,23 @@ func (o *Olm) Close() {
|
|||||||
o.stopRegister = nil
|
o.stopRegister = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stop all pending peer init and send senders before closing websocket
|
||||||
|
o.peerSendMu.Lock()
|
||||||
|
for _, stop := range o.stopPeerInits {
|
||||||
|
if stop != nil {
|
||||||
|
stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
o.stopPeerInits = make(map[string]func())
|
||||||
|
for _, stop := range o.stopPeerSends {
|
||||||
|
if stop != nil {
|
||||||
|
stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
o.stopPeerSends = make(map[string]func())
|
||||||
|
o.jitPendingSites = make(map[int]string)
|
||||||
|
o.peerSendMu.Unlock()
|
||||||
|
|
||||||
// send a disconnect message to the cloud to show disconnected
|
// send a disconnect message to the cloud to show disconnected
|
||||||
if o.websocket != nil {
|
if o.websocket != nil {
|
||||||
o.websocket.SendMessage("olm/disconnecting", map[string]any{})
|
o.websocket.SendMessage("olm/disconnecting", map[string]any{})
|
||||||
|
|||||||
140
olm/peer.go
140
olm/peer.go
@@ -20,31 +20,43 @@ func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if o.stopPeerSend != nil {
|
|
||||||
o.stopPeerSend()
|
|
||||||
o.stopPeerSend = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error marshaling data: %v", err)
|
logger.Error("Error marshaling data: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var siteConfig peers.SiteConfig
|
var siteConfigMsg struct {
|
||||||
if err := json.Unmarshal(jsonData, &siteConfig); err != nil {
|
peers.SiteConfig
|
||||||
|
ChainId string `json:"chainId"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(jsonData, &siteConfigMsg); err != nil {
|
||||||
logger.Error("Error unmarshaling add data: %v", err)
|
logger.Error("Error unmarshaling add data: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if siteConfigMsg.ChainId != "" {
|
||||||
|
o.peerSendMu.Lock()
|
||||||
|
if stop, ok := o.stopPeerSends[siteConfigMsg.ChainId]; ok {
|
||||||
|
stop()
|
||||||
|
delete(o.stopPeerSends, siteConfigMsg.ChainId)
|
||||||
|
}
|
||||||
|
o.peerSendMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
if siteConfigMsg.PublicKey == "" {
|
||||||
|
logger.Warn("Skipping add-peer for site %d (%s): no public key available (site may not be connected)", siteConfigMsg.SiteId, siteConfigMsg.Name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
_ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
|
_ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
|
||||||
|
|
||||||
if err := o.peerManager.AddPeer(siteConfig); err != nil {
|
if err := o.peerManager.AddPeer(siteConfigMsg.SiteConfig); err != nil {
|
||||||
logger.Error("Failed to add peer: %v", err)
|
logger.Error("Failed to add peer: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Successfully added peer for site %d", siteConfig.SiteId)
|
logger.Info("Successfully added peer for site %d", siteConfigMsg.SiteId)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
|
func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
|
||||||
@@ -164,13 +176,21 @@ func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var relayData peers.RelayPeerData
|
var relayData struct {
|
||||||
|
peers.RelayPeerData
|
||||||
|
ChainId string `json:"chainId"`
|
||||||
|
}
|
||||||
if err := json.Unmarshal(jsonData, &relayData); err != nil {
|
if err := json.Unmarshal(jsonData, &relayData); err != nil {
|
||||||
logger.Error("Error unmarshaling relay data: %v", err)
|
logger.Error("Error unmarshaling relay data: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint)
|
if monitor := o.peerManager.GetPeerMonitor(); monitor != nil {
|
||||||
|
monitor.CancelRelaySend(relayData.ChainId)
|
||||||
|
}
|
||||||
|
|
||||||
|
primaryRelay, err := util.ResolveDomainUpstream(relayData.RelayEndpoint, o.tunnelConfig.PublicDNS)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to resolve primary relay endpoint: %v", err)
|
logger.Error("Failed to resolve primary relay endpoint: %v", err)
|
||||||
return
|
return
|
||||||
@@ -197,13 +217,21 @@ func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var relayData peers.UnRelayPeerData
|
var relayData struct {
|
||||||
|
peers.UnRelayPeerData
|
||||||
|
ChainId string `json:"chainId"`
|
||||||
|
}
|
||||||
if err := json.Unmarshal(jsonData, &relayData); err != nil {
|
if err := json.Unmarshal(jsonData, &relayData); err != nil {
|
||||||
logger.Error("Error unmarshaling relay data: %v", err)
|
logger.Error("Error unmarshaling relay data: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
primaryRelay, err := util.ResolveDomain(relayData.Endpoint)
|
if monitor := o.peerManager.GetPeerMonitor(); monitor != nil {
|
||||||
|
monitor.CancelRelaySend(relayData.ChainId)
|
||||||
|
}
|
||||||
|
|
||||||
|
primaryRelay, err := util.ResolveDomainUpstream(relayData.Endpoint, o.tunnelConfig.PublicDNS)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
|
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
|
||||||
}
|
}
|
||||||
@@ -230,7 +258,8 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var handshakeData struct {
|
var handshakeData struct {
|
||||||
SiteId int `json:"siteId"`
|
SiteId int `json:"siteId"`
|
||||||
|
ChainId string `json:"chainId"`
|
||||||
ExitNode struct {
|
ExitNode struct {
|
||||||
PublicKey string `json:"publicKey"`
|
PublicKey string `json:"publicKey"`
|
||||||
Endpoint string `json:"endpoint"`
|
Endpoint string `json:"endpoint"`
|
||||||
@@ -243,6 +272,19 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stop the peer init sender for this chain, if any
|
||||||
|
if handshakeData.ChainId != "" {
|
||||||
|
o.peerSendMu.Lock()
|
||||||
|
if stop, ok := o.stopPeerInits[handshakeData.ChainId]; ok {
|
||||||
|
stop()
|
||||||
|
delete(o.stopPeerInits, handshakeData.ChainId)
|
||||||
|
}
|
||||||
|
// If this chain was initiated by a DNS-triggered JIT request, clear the
|
||||||
|
// pending entry so the site can be re-triggered if needed in the future.
|
||||||
|
delete(o.jitPendingSites, handshakeData.SiteId)
|
||||||
|
o.peerSendMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
// Get existing peer from PeerManager
|
// Get existing peer from PeerManager
|
||||||
_, exists := o.peerManager.GetPeer(handshakeData.SiteId)
|
_, exists := o.peerManager.GetPeer(handshakeData.SiteId)
|
||||||
if exists {
|
if exists {
|
||||||
@@ -273,10 +315,72 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
|
|||||||
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
|
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
|
||||||
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
|
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
|
||||||
|
|
||||||
// Send handshake acknowledgment back to server with retry
|
// Send handshake acknowledgment back to server with retry, keyed by chainId
|
||||||
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
chainId := handshakeData.ChainId
|
||||||
"siteId": handshakeData.SiteId,
|
if chainId == "" {
|
||||||
}, 1*time.Second, 10)
|
chainId = generateChainId()
|
||||||
|
}
|
||||||
|
o.peerSendMu.Lock()
|
||||||
|
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
||||||
|
"siteId": handshakeData.SiteId,
|
||||||
|
"chainId": chainId,
|
||||||
|
}, 2*time.Second, 10)
|
||||||
|
o.stopPeerSends[chainId] = stopFunc
|
||||||
|
o.peerSendMu.Unlock()
|
||||||
|
|
||||||
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
|
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o *Olm) handleCancelChain(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received cancel-chain message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling cancel-chain data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var cancelData struct {
|
||||||
|
ChainId string `json:"chainId"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(jsonData, &cancelData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling cancel-chain data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if cancelData.ChainId == "" {
|
||||||
|
logger.Warn("Received cancel-chain message with no chainId")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
o.peerSendMu.Lock()
|
||||||
|
defer o.peerSendMu.Unlock()
|
||||||
|
|
||||||
|
found := false
|
||||||
|
|
||||||
|
if stop, ok := o.stopPeerInits[cancelData.ChainId]; ok {
|
||||||
|
stop()
|
||||||
|
delete(o.stopPeerInits, cancelData.ChainId)
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
// If this chain was a DNS-triggered JIT request, clear the pending entry so
|
||||||
|
// the site can be re-triggered on the next DNS lookup.
|
||||||
|
for siteId, chainId := range o.jitPendingSites {
|
||||||
|
if chainId == cancelData.ChainId {
|
||||||
|
delete(o.jitPendingSites, siteId)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if stop, ok := o.stopPeerSends[cancelData.ChainId]; ok {
|
||||||
|
stop()
|
||||||
|
delete(o.stopPeerSends, cancelData.ChainId)
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if found {
|
||||||
|
logger.Info("Cancelled chain %s", cancelData.ChainId)
|
||||||
|
} else {
|
||||||
|
logger.Warn("Cancel-chain: no active sender found for chain %s", cancelData.ChainId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -61,6 +61,7 @@ type TunnelConfig struct {
|
|||||||
MTU int
|
MTU int
|
||||||
DNS string
|
DNS string
|
||||||
UpstreamDNS []string
|
UpstreamDNS []string
|
||||||
|
PublicDNS []string
|
||||||
InterfaceName string
|
InterfaceName string
|
||||||
|
|
||||||
// Advanced
|
// Advanced
|
||||||
|
|||||||
@@ -32,7 +32,8 @@ type PeerManagerConfig struct {
|
|||||||
SharedBind *bind.SharedBind
|
SharedBind *bind.SharedBind
|
||||||
// WSClient is optional - if nil, relay messages won't be sent
|
// WSClient is optional - if nil, relay messages won't be sent
|
||||||
WSClient *websocket.Client
|
WSClient *websocket.Client
|
||||||
APIServer *api.API
|
APIServer *api.API
|
||||||
|
PublicDNS []string
|
||||||
}
|
}
|
||||||
|
|
||||||
type PeerManager struct {
|
type PeerManager struct {
|
||||||
@@ -50,7 +51,8 @@ type PeerManager struct {
|
|||||||
// key is the CIDR string, value is a set of siteIds that want this IP
|
// key is the CIDR string, value is a set of siteIds that want this IP
|
||||||
allowedIPClaims map[string]map[int]bool
|
allowedIPClaims map[string]map[int]bool
|
||||||
APIServer *api.API
|
APIServer *api.API
|
||||||
|
publicDNS []string
|
||||||
|
|
||||||
PersistentKeepalive int
|
PersistentKeepalive int
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,6 +67,7 @@ func NewPeerManager(config PeerManagerConfig) *PeerManager {
|
|||||||
allowedIPOwners: make(map[string]int),
|
allowedIPOwners: make(map[string]int),
|
||||||
allowedIPClaims: make(map[string]map[int]bool),
|
allowedIPClaims: make(map[string]map[int]bool),
|
||||||
APIServer: config.APIServer,
|
APIServer: config.APIServer,
|
||||||
|
publicDNS: config.PublicDNS,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the peer monitor
|
// Create the peer monitor
|
||||||
@@ -74,6 +77,7 @@ func NewPeerManager(config PeerManagerConfig) *PeerManager {
|
|||||||
config.LocalIP,
|
config.LocalIP,
|
||||||
config.SharedBind,
|
config.SharedBind,
|
||||||
config.APIServer,
|
config.APIServer,
|
||||||
|
config.PublicDNS,
|
||||||
)
|
)
|
||||||
|
|
||||||
return pm
|
return pm
|
||||||
@@ -106,6 +110,19 @@ func (pm *PeerManager) GetAllPeers() []SiteConfig {
|
|||||||
func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
|
func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
|
||||||
pm.mu.Lock()
|
pm.mu.Lock()
|
||||||
defer pm.mu.Unlock()
|
defer pm.mu.Unlock()
|
||||||
|
|
||||||
|
for _, alias := range siteConfig.Aliases {
|
||||||
|
address := net.ParseIP(alias.AliasAddress)
|
||||||
|
if address == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteConfig.SiteId)
|
||||||
|
}
|
||||||
|
|
||||||
|
if siteConfig.PublicKey == "" {
|
||||||
|
logger.Debug("Skip adding site %d because no pub key", siteConfig.SiteId)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// build the allowed IPs list from the remote subnets and aliases and add them to the peer
|
// build the allowed IPs list from the remote subnets and aliases and add them to the peer
|
||||||
allowedIPs := make([]string, 0, len(siteConfig.RemoteSubnets)+len(siteConfig.Aliases))
|
allowedIPs := make([]string, 0, len(siteConfig.RemoteSubnets)+len(siteConfig.Aliases))
|
||||||
@@ -129,7 +146,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
|
|||||||
wgConfig := siteConfig
|
wgConfig := siteConfig
|
||||||
wgConfig.AllowedIps = ownedIPs
|
wgConfig.AllowedIps = ownedIPs
|
||||||
|
|
||||||
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil {
|
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive, pm.publicDNS); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -139,14 +156,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
|
|||||||
if err := network.AddRoutes(siteConfig.RemoteSubnets, pm.interfaceName); err != nil {
|
if err := network.AddRoutes(siteConfig.RemoteSubnets, pm.interfaceName); err != nil {
|
||||||
logger.Error("Failed to add routes for remote subnets: %v", err)
|
logger.Error("Failed to add routes for remote subnets: %v", err)
|
||||||
}
|
}
|
||||||
for _, alias := range siteConfig.Aliases {
|
|
||||||
address := net.ParseIP(alias.AliasAddress)
|
|
||||||
if address == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
pm.dnsProxy.AddDNSRecord(alias.Alias, address)
|
|
||||||
}
|
|
||||||
|
|
||||||
monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
|
monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
|
||||||
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
|
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
|
||||||
|
|
||||||
@@ -270,7 +280,7 @@ func (pm *PeerManager) RemovePeer(siteId int) error {
|
|||||||
ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
|
ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
|
||||||
wgConfig := promotedPeer
|
wgConfig := promotedPeer
|
||||||
wgConfig.AllowedIps = ownedIPs
|
wgConfig.AllowedIps = ownedIPs
|
||||||
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil {
|
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive, pm.publicDNS); err != nil {
|
||||||
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
|
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -346,7 +356,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
|
|||||||
wgConfig := siteConfig
|
wgConfig := siteConfig
|
||||||
wgConfig.AllowedIps = ownedIPs
|
wgConfig.AllowedIps = ownedIPs
|
||||||
|
|
||||||
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil {
|
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive, pm.publicDNS); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -356,7 +366,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
|
|||||||
promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
|
promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
|
||||||
promotedWgConfig := promotedPeer
|
promotedWgConfig := promotedPeer
|
||||||
promotedWgConfig.AllowedIps = promotedOwnedIPs
|
promotedWgConfig.AllowedIps = promotedOwnedIPs
|
||||||
if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil {
|
if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive, pm.publicDNS); err != nil {
|
||||||
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
|
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -433,7 +443,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
|
|||||||
if address == nil {
|
if address == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
pm.dnsProxy.AddDNSRecord(alias.Alias, address)
|
pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteConfig.SiteId)
|
||||||
}
|
}
|
||||||
|
|
||||||
pm.peerMonitor.UpdateHolepunchEndpoint(siteConfig.SiteId, siteConfig.Endpoint)
|
pm.peerMonitor.UpdateHolepunchEndpoint(siteConfig.SiteId, siteConfig.Endpoint)
|
||||||
@@ -713,7 +723,7 @@ func (pm *PeerManager) AddAlias(siteId int, alias Alias) error {
|
|||||||
|
|
||||||
address := net.ParseIP(alias.AliasAddress)
|
address := net.ParseIP(alias.AliasAddress)
|
||||||
if address != nil {
|
if address != nil {
|
||||||
pm.dnsProxy.AddDNSRecord(alias.Alias, address)
|
pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteId)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add an allowed IP for the alias
|
// Add an allowed IP for the alias
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package monitor
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -31,9 +33,14 @@ type PeerMonitor struct {
|
|||||||
monitors map[int]*Client
|
monitors map[int]*Client
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
running bool
|
running bool
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
maxAttempts int
|
maxAttempts int
|
||||||
wsClient *websocket.Client
|
wsClient *websocket.Client
|
||||||
|
publicDNS []string
|
||||||
|
|
||||||
|
// Relay sender tracking
|
||||||
|
relaySends map[string]func()
|
||||||
|
relaySendMu sync.Mutex
|
||||||
|
|
||||||
// Netstack fields
|
// Netstack fields
|
||||||
middleDev *middleDevice.MiddleDevice
|
middleDev *middleDevice.MiddleDevice
|
||||||
@@ -47,13 +54,13 @@ type PeerMonitor struct {
|
|||||||
nsWg sync.WaitGroup
|
nsWg sync.WaitGroup
|
||||||
|
|
||||||
// Holepunch testing fields
|
// Holepunch testing fields
|
||||||
sharedBind *bind.SharedBind
|
sharedBind *bind.SharedBind
|
||||||
holepunchTester *holepunch.HolepunchTester
|
holepunchTester *holepunch.HolepunchTester
|
||||||
holepunchTimeout time.Duration
|
holepunchTimeout time.Duration
|
||||||
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
|
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
|
||||||
holepunchStatus map[int]bool // siteID -> connected status
|
holepunchStatus map[int]bool // siteID -> connected status
|
||||||
holepunchStopChan chan struct{}
|
holepunchStopChan chan struct{}
|
||||||
holepunchUpdateChan chan struct{}
|
holepunchUpdateChan chan struct{}
|
||||||
|
|
||||||
// Relay tracking fields
|
// Relay tracking fields
|
||||||
relayedPeers map[int]bool // siteID -> whether the peer is currently relayed
|
relayedPeers map[int]bool // siteID -> whether the peer is currently relayed
|
||||||
@@ -82,7 +89,13 @@ type PeerMonitor struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewPeerMonitor creates a new peer monitor with the given callback
|
// NewPeerMonitor creates a new peer monitor with the given callback
|
||||||
func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API) *PeerMonitor {
|
func generateChainId() string {
|
||||||
|
b := make([]byte, 8)
|
||||||
|
_, _ = rand.Read(b)
|
||||||
|
return hex.EncodeToString(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API, publicDNS []string) *PeerMonitor {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
pm := &PeerMonitor{
|
pm := &PeerMonitor{
|
||||||
monitors: make(map[int]*Client),
|
monitors: make(map[int]*Client),
|
||||||
@@ -91,6 +104,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
|
|||||||
wsClient: wsClient,
|
wsClient: wsClient,
|
||||||
middleDev: middleDev,
|
middleDev: middleDev,
|
||||||
localIP: localIP,
|
localIP: localIP,
|
||||||
|
publicDNS: publicDNS,
|
||||||
activePorts: make(map[uint16]bool),
|
activePorts: make(map[uint16]bool),
|
||||||
nsCtx: ctx,
|
nsCtx: ctx,
|
||||||
nsCancel: cancel,
|
nsCancel: cancel,
|
||||||
@@ -99,6 +113,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
|
|||||||
holepunchEndpoints: make(map[int]string),
|
holepunchEndpoints: make(map[int]string),
|
||||||
holepunchStatus: make(map[int]bool),
|
holepunchStatus: make(map[int]bool),
|
||||||
relayedPeers: make(map[int]bool),
|
relayedPeers: make(map[int]bool),
|
||||||
|
relaySends: make(map[string]func()),
|
||||||
holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures
|
holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures
|
||||||
holepunchFailures: make(map[int]int),
|
holepunchFailures: make(map[int]int),
|
||||||
// Rapid initial test settings: complete within ~1.5 seconds
|
// Rapid initial test settings: complete within ~1.5 seconds
|
||||||
@@ -124,7 +139,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
|
|||||||
|
|
||||||
// Initialize holepunch tester if sharedBind is available
|
// Initialize holepunch tester if sharedBind is available
|
||||||
if sharedBind != nil {
|
if sharedBind != nil {
|
||||||
pm.holepunchTester = holepunch.NewHolepunchTester(sharedBind)
|
pm.holepunchTester = holepunch.NewHolepunchTester(sharedBind, publicDNS)
|
||||||
}
|
}
|
||||||
|
|
||||||
return pm
|
return pm
|
||||||
@@ -396,20 +411,23 @@ func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status Connectio
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendRelay sends a relay message to the server
|
// sendRelay sends a relay message to the server with retry, keyed by chainId
|
||||||
func (pm *PeerMonitor) sendRelay(siteID int) error {
|
func (pm *PeerMonitor) sendRelay(siteID int) error {
|
||||||
if pm.wsClient == nil {
|
if pm.wsClient == nil {
|
||||||
return fmt.Errorf("websocket client is nil")
|
return fmt.Errorf("websocket client is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{
|
chainId := generateChainId()
|
||||||
"siteId": siteID,
|
stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/relay", map[string]interface{}{
|
||||||
})
|
"siteId": siteID,
|
||||||
if err != nil {
|
"chainId": chainId,
|
||||||
logger.Error("Failed to send registration message: %v", err)
|
}, 2*time.Second, 10)
|
||||||
return err
|
|
||||||
}
|
pm.relaySendMu.Lock()
|
||||||
logger.Info("Sent relay message")
|
pm.relaySends[chainId] = stopFunc
|
||||||
|
pm.relaySendMu.Unlock()
|
||||||
|
|
||||||
|
logger.Info("Sent relay message for site %d (chain %s)", siteID, chainId)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -419,23 +437,52 @@ func (pm *PeerMonitor) RequestRelay(siteID int) error {
|
|||||||
return pm.sendRelay(siteID)
|
return pm.sendRelay(siteID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendUnRelay sends an unrelay message to the server
|
// sendUnRelay sends an unrelay message to the server with retry, keyed by chainId
|
||||||
func (pm *PeerMonitor) sendUnRelay(siteID int) error {
|
func (pm *PeerMonitor) sendUnRelay(siteID int) error {
|
||||||
if pm.wsClient == nil {
|
if pm.wsClient == nil {
|
||||||
return fmt.Errorf("websocket client is nil")
|
return fmt.Errorf("websocket client is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
err := pm.wsClient.SendMessage("olm/wg/unrelay", map[string]interface{}{
|
chainId := generateChainId()
|
||||||
"siteId": siteID,
|
stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/unrelay", map[string]interface{}{
|
||||||
})
|
"siteId": siteID,
|
||||||
if err != nil {
|
"chainId": chainId,
|
||||||
logger.Error("Failed to send registration message: %v", err)
|
}, 2*time.Second, 10)
|
||||||
return err
|
|
||||||
}
|
pm.relaySendMu.Lock()
|
||||||
logger.Info("Sent unrelay message")
|
pm.relaySends[chainId] = stopFunc
|
||||||
|
pm.relaySendMu.Unlock()
|
||||||
|
|
||||||
|
logger.Info("Sent unrelay message for site %d (chain %s)", siteID, chainId)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CancelRelaySend stops the interval sender for the given chainId, if one exists.
|
||||||
|
// If chainId is empty, all active relay senders are stopped.
|
||||||
|
func (pm *PeerMonitor) CancelRelaySend(chainId string) {
|
||||||
|
pm.relaySendMu.Lock()
|
||||||
|
defer pm.relaySendMu.Unlock()
|
||||||
|
|
||||||
|
if chainId == "" {
|
||||||
|
for id, stop := range pm.relaySends {
|
||||||
|
if stop != nil {
|
||||||
|
stop()
|
||||||
|
}
|
||||||
|
delete(pm.relaySends, id)
|
||||||
|
}
|
||||||
|
logger.Info("Cancelled all relay senders")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if stop, ok := pm.relaySends[chainId]; ok {
|
||||||
|
stop()
|
||||||
|
delete(pm.relaySends, chainId)
|
||||||
|
logger.Info("Cancelled relay sender for chain %s", chainId)
|
||||||
|
} else {
|
||||||
|
logger.Warn("CancelRelaySend: no active sender for chain %s", chainId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Stop stops monitoring all peers
|
// Stop stops monitoring all peers
|
||||||
func (pm *PeerMonitor) Stop() {
|
func (pm *PeerMonitor) Stop() {
|
||||||
// Stop holepunch monitor first (outside of mutex to avoid deadlock)
|
// Stop holepunch monitor first (outside of mutex to avoid deadlock)
|
||||||
@@ -534,7 +581,7 @@ func (pm *PeerMonitor) runHolepunchMonitor() {
|
|||||||
pm.holepunchCurrentInterval = pm.holepunchMinInterval
|
pm.holepunchCurrentInterval = pm.holepunchMinInterval
|
||||||
currentInterval := pm.holepunchCurrentInterval
|
currentInterval := pm.holepunchCurrentInterval
|
||||||
pm.mutex.Unlock()
|
pm.mutex.Unlock()
|
||||||
|
|
||||||
timer.Reset(currentInterval)
|
timer.Reset(currentInterval)
|
||||||
logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval)
|
logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval)
|
||||||
case <-timer.C:
|
case <-timer.C:
|
||||||
@@ -677,6 +724,16 @@ func (pm *PeerMonitor) Close() {
|
|||||||
// Stop holepunch monitor first (outside of mutex to avoid deadlock)
|
// Stop holepunch monitor first (outside of mutex to avoid deadlock)
|
||||||
pm.stopHolepunchMonitor()
|
pm.stopHolepunchMonitor()
|
||||||
|
|
||||||
|
// Stop all pending relay senders
|
||||||
|
pm.relaySendMu.Lock()
|
||||||
|
for chainId, stop := range pm.relaySends {
|
||||||
|
if stop != nil {
|
||||||
|
stop()
|
||||||
|
}
|
||||||
|
delete(pm.relaySends, chainId)
|
||||||
|
}
|
||||||
|
pm.relaySendMu.Unlock()
|
||||||
|
|
||||||
pm.mutex.Lock()
|
pm.mutex.Lock()
|
||||||
defer pm.mutex.Unlock()
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
|
|||||||
@@ -11,14 +11,14 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// ConfigurePeer sets up or updates a peer within the WireGuard device
|
// ConfigurePeer sets up or updates a peer within the WireGuard device
|
||||||
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool, persistentKeepalive int) error {
|
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool, persistentKeepalive int, publicDNS []string) error {
|
||||||
var endpoint string
|
var endpoint string
|
||||||
if relay && siteConfig.RelayEndpoint != "" {
|
if relay && siteConfig.RelayEndpoint != "" {
|
||||||
endpoint = formatEndpoint(siteConfig.RelayEndpoint)
|
endpoint = formatEndpoint(siteConfig.RelayEndpoint)
|
||||||
} else {
|
} else {
|
||||||
endpoint = formatEndpoint(siteConfig.Endpoint)
|
endpoint = formatEndpoint(siteConfig.Endpoint)
|
||||||
}
|
}
|
||||||
siteHost, err := util.ResolveDomain(endpoint)
|
siteHost, err := util.ResolveDomainUpstream(endpoint, publicDNS)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err)
|
return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -388,6 +388,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
|
|||||||
tokenData := map[string]interface{}{
|
tokenData := map[string]interface{}{
|
||||||
"olmId": c.config.ID,
|
"olmId": c.config.ID,
|
||||||
"secret": c.config.Secret,
|
"secret": c.config.Secret,
|
||||||
|
"userToken": c.config.UserToken,
|
||||||
"orgId": c.config.OrgID,
|
"orgId": c.config.OrgID,
|
||||||
}
|
}
|
||||||
jsonData, err := json.Marshal(tokenData)
|
jsonData, err := json.Marshal(tokenData)
|
||||||
|
|||||||
Reference in New Issue
Block a user