Compare commits

...

12 Commits

Author SHA1 Message Date
Owen
da0ad21fd4 Update test
Former-commit-id: 449e631aae
2025-12-21 21:07:14 -05:00
Owen
44c8d871c2 Build binaries and do release
Former-commit-id: 8aaefde72a
2025-12-21 21:04:24 -05:00
Owen
96a88057f9 Update mod
Former-commit-id: b026bea86e
2025-12-21 21:03:48 -05:00
Owen
d96fe6391e Remove replace
Former-commit-id: 5551eff130
2025-12-21 21:03:48 -05:00
Owen
fe7fd31955 Sending DNS over the tunnel works
Former-commit-id: ca763fff2d
2025-12-21 21:03:48 -05:00
Owen
86b19f243e Remove exit nodes from HPing if peers are removed
Former-commit-id: 0c96d3c25c
2025-12-21 21:03:48 -05:00
Owen
d0940d03c4 Cleanup unclean shutdown cross platform
Former-commit-id: c1a2efd9d2
2025-12-21 21:03:48 -05:00
Owen Schwartz
708c761fa6 Merge pull request #63 from fosrl/wildcard-resources
Wildcard resources

Former-commit-id: ee8b93b13a
2025-12-16 21:50:22 -05:00
Owen
78dc6508a4 Support wildcard alias records
Former-commit-id: cec79bf014
2025-12-16 21:33:41 -05:00
Owen
7f6c824122 Pull 21820 from config
Former-commit-id: 56f4614899
2025-12-16 18:36:04 -05:00
Owen
9ba3569573 Remove acciential file
Former-commit-id: c3a12bd2a9
2025-12-11 23:26:21 -05:00
Owen
fd38f4cc59 Fix test
Former-commit-id: e0efe8f950
2025-12-11 23:25:56 -05:00
22 changed files with 1205 additions and 98 deletions

View File

@@ -586,28 +586,28 @@ jobs:
# sarif_file: trivy-ghcr.sarif
# category: Image Vulnerability Scan
# - name: Build binaries
# env:
# CGO_ENABLED: "0"
# GOFLAGS: "-trimpath"
# run: |
# set -euo pipefail
# TAG_VAR="${TAG}"
# make go-build-release tag=$TAG_VAR
# shell: bash
- name: Build binaries
env:
CGO_ENABLED: "0"
GOFLAGS: "-trimpath"
run: |
set -euo pipefail
TAG_VAR="${TAG}"
make go-build-release tag=$TAG_VAR
shell: bash
# - name: Create GitHub Release
# uses: softprops/action-gh-release@5be0e66d93ac7ed76da52eca8bb058f665c3a5fe # v2.4.2
# with:
# tag_name: ${{ env.TAG }}
# generate_release_notes: true
# prerelease: ${{ env.IS_RC == 'true' }}
# files: |
# bin/*
# fail_on_unmatched_files: true
# draft: true
# body: |
# ## Container Images
# - GHCR: `${{ env.GHCR_REF }}`
# - Docker Hub: `${{ env.DH_REF || 'N/A' }}`
# **Digest:** `${{ steps.build.outputs.digest }}`
- name: Create GitHub Release
uses: softprops/action-gh-release@5be0e66d93ac7ed76da52eca8bb058f665c3a5fe # v2.4.2
with:
tag_name: ${{ env.TAG }}
generate_release_notes: true
prerelease: ${{ env.IS_RC == 'true' }}
files: |
bin/*
fail_on_unmatched_files: true
draft: true
body: |
## Container Images
- GHCR: `${{ env.GHCR_REF }}`
- Docker Hub: `${{ env.DH_REF || 'N/A' }}`
**Digest:** `${{ steps.build.outputs.digest }}`

View File

@@ -11,24 +11,15 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- name: Clone fosrl/newt
uses: actions/checkout@v6
with:
repository: fosrl/newt
path: ../newt
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Set up Go
uses: actions/setup-go@v6
uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # v6.1.0
with:
go-version: 1.25
- name: Build go
run: go build
- name: Build Docker image
run: make build
- name: Build binaries
run: make go-build-release
- name: Build Docker image
run: make docker-build-release

View File

@@ -43,6 +43,7 @@ type OlmConfig struct {
DisableHolepunch bool `json:"disableHolepunch"`
TlsClientCert string `json:"tlsClientCert"`
OverrideDNS bool `json:"overrideDNS"`
TunnelDNS bool `json:"tunnelDNS"`
DisableRelay bool `json:"disableRelay"`
// DoNotCreateNewClient bool `json:"doNotCreateNewClient"`
@@ -88,6 +89,7 @@ func DefaultConfig() *OlmConfig {
PingInterval: "3s",
PingTimeout: "5s",
DisableHolepunch: false,
TunnelDNS: false,
// DoNotCreateNewClient: false,
sources: make(map[string]string),
}
@@ -105,6 +107,7 @@ func DefaultConfig() *OlmConfig {
config.sources["pingTimeout"] = string(SourceDefault)
config.sources["disableHolepunch"] = string(SourceDefault)
config.sources["overrideDNS"] = string(SourceDefault)
config.sources["tunnelDNS"] = string(SourceDefault)
config.sources["disableRelay"] = string(SourceDefault)
// config.sources["doNotCreateNewClient"] = string(SourceDefault)
@@ -265,6 +268,10 @@ func loadConfigFromEnv(config *OlmConfig) {
config.DisableRelay = true
config.sources["disableRelay"] = string(SourceEnv)
}
if val := os.Getenv("TUNNEL_DNS"); val == "true" {
config.TunnelDNS = true
config.sources["tunnelDNS"] = string(SourceEnv)
}
// if val := os.Getenv("DO_NOT_CREATE_NEW_CLIENT"); val == "true" {
// config.DoNotCreateNewClient = true
// config.sources["doNotCreateNewClient"] = string(SourceEnv)
@@ -295,6 +302,7 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
"disableHolepunch": config.DisableHolepunch,
"overrideDNS": config.OverrideDNS,
"disableRelay": config.DisableRelay,
"tunnelDNS": config.TunnelDNS,
// "doNotCreateNewClient": config.DoNotCreateNewClient,
}
@@ -318,6 +326,7 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
serviceFlags.BoolVar(&config.DisableHolepunch, "disable-holepunch", config.DisableHolepunch, "Disable hole punching")
serviceFlags.BoolVar(&config.OverrideDNS, "override-dns", config.OverrideDNS, "Override system DNS settings")
serviceFlags.BoolVar(&config.DisableRelay, "disable-relay", config.DisableRelay, "Disable relay connections")
serviceFlags.BoolVar(&config.TunnelDNS, "tunnel-dns", config.TunnelDNS, "Use tunnel for DNS traffic")
// serviceFlags.BoolVar(&config.DoNotCreateNewClient, "do-not-create-new-client", config.DoNotCreateNewClient, "Do not create new client")
version := serviceFlags.Bool("version", false, "Print the version")
@@ -393,6 +402,9 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
if config.DisableRelay != origValues["disableRelay"].(bool) {
config.sources["disableRelay"] = string(SourceCLI)
}
if config.TunnelDNS != origValues["tunnelDNS"].(bool) {
config.sources["tunnelDNS"] = string(SourceCLI)
}
// if config.DoNotCreateNewClient != origValues["doNotCreateNewClient"].(bool) {
// config.sources["doNotCreateNewClient"] = string(SourceCLI)
// }
@@ -606,6 +618,7 @@ func (c *OlmConfig) ShowConfig() {
fmt.Println("\nAdvanced:")
fmt.Printf(" disable-holepunch = %v [%s]\n", c.DisableHolepunch, getSource("disableHolepunch"))
fmt.Printf(" override-dns = %v [%s]\n", c.OverrideDNS, getSource("overrideDNS"))
fmt.Printf(" tunnel-dns = %v [%s]\n", c.TunnelDNS, getSource("tunnelDNS"))
fmt.Printf(" disable-relay = %v [%s]\n", c.DisableRelay, getSource("disableRelay"))
// fmt.Printf(" do-not-create-new-client = %v [%s]\n", c.DoNotCreateNewClient, getSource("doNotCreateNewClient"))
if c.TlsClientCert != "" {

View File

@@ -34,18 +34,26 @@ type DNSProxy struct {
ep *channel.Endpoint
proxyIP netip.Addr
upstreamDNS []string
tunnelDNS bool // Whether to tunnel DNS queries over WireGuard or to spit them out locally
mtu int
tunDevice tun.Device // Direct reference to underlying TUN device for responses
middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering
recordStore *DNSRecordStore // Local DNS records
// Tunnel DNS fields - for sending queries over WireGuard
tunnelIP netip.Addr // WireGuard interface IP (source for tunneled queries)
tunnelStack *stack.Stack // Separate netstack for outbound tunnel queries
tunnelEp *channel.Endpoint
tunnelActivePorts map[uint16]bool
tunnelPortsLock sync.Mutex
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
// NewDNSProxy creates a new DNS proxy
func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string) (*DNSProxy, error) {
func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) {
proxyIP, err := PickIPFromSubnet(utilitySubnet)
if err != nil {
return nil, fmt.Errorf("failed to pick DNS proxy IP from subnet: %v", err)
@@ -58,17 +66,28 @@ func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu in
ctx, cancel := context.WithCancel(context.Background())
proxy := &DNSProxy{
proxyIP: proxyIP,
mtu: mtu,
tunDevice: tunDevice,
middleDevice: middleDevice,
upstreamDNS: upstreamDns,
recordStore: NewDNSRecordStore(),
ctx: ctx,
cancel: cancel,
proxyIP: proxyIP,
mtu: mtu,
tunDevice: tunDevice,
middleDevice: middleDevice,
upstreamDNS: upstreamDns,
tunnelDNS: tunnelDns,
recordStore: NewDNSRecordStore(),
tunnelActivePorts: make(map[uint16]bool),
ctx: ctx,
cancel: cancel,
}
// Create gvisor netstack
// Parse tunnel IP if provided (needed for tunneled DNS)
if tunnelIP != "" {
addr, err := netip.ParseAddr(tunnelIP)
if err != nil {
return nil, fmt.Errorf("failed to parse tunnel IP: %v", err)
}
proxy.tunnelIP = addr
}
// Create gvisor netstack for receiving DNS queries
stackOpts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
@@ -101,9 +120,104 @@ func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu in
NIC: 1,
})
// Initialize tunnel netstack if tunnel DNS is enabled
if tunnelDns {
if !proxy.tunnelIP.IsValid() {
return nil, fmt.Errorf("tunnel IP is required when tunnelDNS is enabled")
}
// TODO: DO WE NEED TO ESTABLISH ANOTHER NETSTACK HERE OR CAN WE COMBINE WITH WGTESTER?
if err := proxy.initTunnelNetstack(); err != nil {
return nil, fmt.Errorf("failed to initialize tunnel netstack: %v", err)
}
}
return proxy, nil
}
// initTunnelNetstack creates a separate netstack for outbound DNS queries through the tunnel
func (p *DNSProxy) initTunnelNetstack() error {
// Create gvisor netstack for outbound tunnel queries
stackOpts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
HandleLocal: true,
}
p.tunnelEp = channel.New(256, uint32(p.mtu), "")
p.tunnelStack = stack.New(stackOpts)
// Create NIC
if err := p.tunnelStack.CreateNIC(1, p.tunnelEp); err != nil {
return fmt.Errorf("failed to create tunnel NIC: %v", err)
}
// Add tunnel IP address (WireGuard interface IP)
ipBytes := p.tunnelIP.As4()
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddrFrom4(ipBytes).WithPrefix(),
}
if err := p.tunnelStack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
return fmt.Errorf("failed to add tunnel protocol address: %v", err)
}
// Add default route
p.tunnelStack.AddRoute(tcpip.Route{
Destination: header.IPv4EmptySubnet,
NIC: 1,
})
// Register filter rule on MiddleDevice to intercept responses
p.middleDevice.AddRule(p.tunnelIP, p.handleTunnelResponse)
return nil
}
// handleTunnelResponse handles packets coming back from the tunnel destined for the tunnel IP
func (p *DNSProxy) handleTunnelResponse(packet []byte) bool {
// Check if it's UDP
proto, ok := util.GetProtocol(packet)
if !ok || proto != 17 { // UDP
return false
}
// Check destination port - should be one of our active outbound ports
port, ok := util.GetDestPort(packet)
if !ok {
return false
}
// Check if we are expecting a response on this port
p.tunnelPortsLock.Lock()
active := p.tunnelActivePorts[uint16(port)]
p.tunnelPortsLock.Unlock()
if !active {
return false
}
// Inject into tunnel netstack
version := packet[0] >> 4
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
})
switch version {
case 4:
p.tunnelEp.InjectInbound(ipv4.ProtocolNumber, pkb)
case 6:
p.tunnelEp.InjectInbound(ipv6.ProtocolNumber, pkb)
default:
pkb.DecRef()
return false
}
pkb.DecRef()
return true // Handled
}
// Start starts the DNS proxy and registers with the filter
func (p *DNSProxy) Start() error {
// Install packet filter rule
@@ -114,7 +228,13 @@ func (p *DNSProxy) Start() error {
go p.runDNSListener()
go p.runPacketSender()
logger.Info("DNS proxy started on %s:%d", p.proxyIP.String(), DNSPort)
// Start tunnel packet sender if tunnel DNS is enabled
if p.tunnelDNS {
p.wg.Add(1)
go p.runTunnelPacketSender()
}
logger.Info("DNS proxy started on %s:%d (tunnelDNS=%v)", p.proxyIP.String(), DNSPort, p.tunnelDNS)
return nil
}
@@ -122,6 +242,9 @@ func (p *DNSProxy) Start() error {
func (p *DNSProxy) Stop() {
if p.middleDevice != nil {
p.middleDevice.RemoveRule(p.proxyIP)
if p.tunnelDNS && p.tunnelIP.IsValid() {
p.middleDevice.RemoveRule(p.tunnelIP)
}
}
p.cancel()
@@ -130,12 +253,21 @@ func (p *DNSProxy) Stop() {
p.ep.Close()
}
// Close tunnel endpoint if it exists
if p.tunnelEp != nil {
p.tunnelEp.Close()
}
p.wg.Wait()
if p.stack != nil {
p.stack.Close()
}
if p.tunnelStack != nil {
p.tunnelStack.Close()
}
logger.Info("DNS proxy stopped")
}
@@ -348,8 +480,16 @@ func (p *DNSProxy) forwardToUpstream(query *dns.Msg) *dns.Msg {
return response
}
// queryUpstream sends a DNS query to upstream server using miekg/dns
// queryUpstream sends a DNS query to upstream server
func (p *DNSProxy) queryUpstream(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) {
if p.tunnelDNS {
return p.queryUpstreamTunnel(server, query, timeout)
}
return p.queryUpstreamDirect(server, query, timeout)
}
// queryUpstreamDirect sends a DNS query to upstream server using miekg/dns directly (host networking)
func (p *DNSProxy) queryUpstreamDirect(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) {
client := &dns.Client{
Timeout: timeout,
}
@@ -362,6 +502,155 @@ func (p *DNSProxy) queryUpstream(server string, query *dns.Msg, timeout time.Dur
return response, nil
}
// queryUpstreamTunnel sends a DNS query through the WireGuard tunnel
func (p *DNSProxy) queryUpstreamTunnel(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) {
// Dial through the tunnel netstack
conn, port, err := p.dialTunnel("udp", server)
if err != nil {
return nil, fmt.Errorf("failed to dial tunnel: %v", err)
}
defer func() {
conn.Close()
p.removeTunnelPort(port)
}()
// Pack the query
queryData, err := query.Pack()
if err != nil {
return nil, fmt.Errorf("failed to pack query: %v", err)
}
// Set deadline
conn.SetDeadline(time.Now().Add(timeout))
// Send the query
_, err = conn.Write(queryData)
if err != nil {
return nil, fmt.Errorf("failed to send query: %v", err)
}
// Read the response
buf := make([]byte, 4096)
n, err := conn.Read(buf)
if err != nil {
return nil, fmt.Errorf("failed to read response: %v", err)
}
// Parse the response
response := new(dns.Msg)
if err := response.Unpack(buf[:n]); err != nil {
return nil, fmt.Errorf("failed to unpack response: %v", err)
}
return response, nil
}
// dialTunnel creates a UDP connection through the tunnel netstack
func (p *DNSProxy) dialTunnel(network, addr string) (net.Conn, uint16, error) {
if p.tunnelStack == nil {
return nil, 0, fmt.Errorf("tunnel netstack not initialized")
}
// Parse remote address
raddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, 0, err
}
// Use tunnel IP as source
ipBytes := p.tunnelIP.As4()
// Create UDP connection with ephemeral port
laddr := &tcpip.FullAddress{
NIC: 1,
Addr: tcpip.AddrFrom4(ipBytes),
Port: 0,
}
raddrTcpip := &tcpip.FullAddress{
NIC: 1,
Addr: tcpip.AddrFrom4([4]byte(raddr.IP.To4())),
Port: uint16(raddr.Port),
}
conn, err := gonet.DialUDP(p.tunnelStack, laddr, raddrTcpip, ipv4.ProtocolNumber)
if err != nil {
return nil, 0, err
}
// Get local port
localAddr := conn.LocalAddr().(*net.UDPAddr)
port := uint16(localAddr.Port)
// Register port so we can receive responses
p.tunnelPortsLock.Lock()
p.tunnelActivePorts[port] = true
p.tunnelPortsLock.Unlock()
return conn, port, nil
}
// removeTunnelPort removes a port from the active ports map
func (p *DNSProxy) removeTunnelPort(port uint16) {
p.tunnelPortsLock.Lock()
delete(p.tunnelActivePorts, port)
p.tunnelPortsLock.Unlock()
}
// runTunnelPacketSender reads packets from tunnel netstack and injects them into WireGuard
func (p *DNSProxy) runTunnelPacketSender() {
defer p.wg.Done()
logger.Debug("DNS tunnel packet sender goroutine started")
ticker := time.NewTicker(1 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-p.ctx.Done():
logger.Debug("DNS tunnel packet sender exiting")
// Drain any remaining packets
for {
pkt := p.tunnelEp.Read()
if pkt == nil {
break
}
pkt.DecRef()
}
return
case <-ticker.C:
// Try to read packets
for i := 0; i < 10; i++ {
pkt := p.tunnelEp.Read()
if pkt == nil {
break
}
// Extract packet data
slices := pkt.AsSlices()
if len(slices) > 0 {
var totalSize int
for _, slice := range slices {
totalSize += len(slice)
}
buf := make([]byte, totalSize)
pos := 0
for _, slice := range slices {
copy(buf[pos:], slice)
pos += len(slice)
}
// Inject into MiddleDevice (outbound to WG)
p.middleDevice.InjectOutbound(buf)
}
pkt.DecRef()
}
}
}
}
// runPacketSender sends packets from netstack back to TUN
func (p *DNSProxy) runPacketSender() {
defer p.wg.Done()

View File

@@ -2,6 +2,7 @@ package dns
import (
"net"
"strings"
"sync"
"github.com/miekg/dns"
@@ -17,21 +18,26 @@ const (
// DNSRecordStore manages local DNS records for A and AAAA queries
type DNSRecordStore struct {
mu sync.RWMutex
aRecords map[string][]net.IP // domain -> list of IPv4 addresses
aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses
mu sync.RWMutex
aRecords map[string][]net.IP // domain -> list of IPv4 addresses
aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses
aWildcards map[string][]net.IP // wildcard pattern -> list of IPv4 addresses
aaaaWildcards map[string][]net.IP // wildcard pattern -> list of IPv6 addresses
}
// NewDNSRecordStore creates a new DNS record store
func NewDNSRecordStore() *DNSRecordStore {
return &DNSRecordStore{
aRecords: make(map[string][]net.IP),
aaaaRecords: make(map[string][]net.IP),
aRecords: make(map[string][]net.IP),
aaaaRecords: make(map[string][]net.IP),
aWildcards: make(map[string][]net.IP),
aaaaWildcards: make(map[string][]net.IP),
}
}
// AddRecord adds a DNS record mapping (A or AAAA)
// domain should be in FQDN format (e.g., "example.com.")
// domain can contain wildcards: * (0+ chars) and ? (exactly 1 char)
// ip should be a valid IPv4 or IPv6 address
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
s.mu.Lock()
@@ -45,12 +51,23 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
// Normalize domain to lowercase
domain = dns.Fqdn(domain)
// Check if domain contains wildcards
isWildcard := strings.ContainsAny(domain, "*?")
if ip.To4() != nil {
// IPv4 address
s.aRecords[domain] = append(s.aRecords[domain], ip)
if isWildcard {
s.aWildcards[domain] = append(s.aWildcards[domain], ip)
} else {
s.aRecords[domain] = append(s.aRecords[domain], ip)
}
} else if ip.To16() != nil {
// IPv6 address
s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip)
if isWildcard {
s.aaaaWildcards[domain] = append(s.aaaaWildcards[domain], ip)
} else {
s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip)
}
} else {
return &net.ParseError{Type: "IP address", Text: ip.String()}
}
@@ -59,7 +76,7 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
}
// RemoveRecord removes a specific DNS record mapping
// If ip is nil, removes all records for the domain
// If ip is nil, removes all records for the domain (including wildcards)
func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -72,33 +89,60 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
// Normalize domain to lowercase
domain = dns.Fqdn(domain)
// Check if domain contains wildcards
isWildcard := strings.ContainsAny(domain, "*?")
if ip == nil {
// Remove all records for this domain
delete(s.aRecords, domain)
delete(s.aaaaRecords, domain)
if isWildcard {
delete(s.aWildcards, domain)
delete(s.aaaaWildcards, domain)
} else {
delete(s.aRecords, domain)
delete(s.aaaaRecords, domain)
}
return
}
if ip.To4() != nil {
// Remove specific IPv4 address
if ips, ok := s.aRecords[domain]; ok {
s.aRecords[domain] = removeIP(ips, ip)
if len(s.aRecords[domain]) == 0 {
delete(s.aRecords, domain)
if isWildcard {
if ips, ok := s.aWildcards[domain]; ok {
s.aWildcards[domain] = removeIP(ips, ip)
if len(s.aWildcards[domain]) == 0 {
delete(s.aWildcards, domain)
}
}
} else {
if ips, ok := s.aRecords[domain]; ok {
s.aRecords[domain] = removeIP(ips, ip)
if len(s.aRecords[domain]) == 0 {
delete(s.aRecords, domain)
}
}
}
} else if ip.To16() != nil {
// Remove specific IPv6 address
if ips, ok := s.aaaaRecords[domain]; ok {
s.aaaaRecords[domain] = removeIP(ips, ip)
if len(s.aaaaRecords[domain]) == 0 {
delete(s.aaaaRecords, domain)
if isWildcard {
if ips, ok := s.aaaaWildcards[domain]; ok {
s.aaaaWildcards[domain] = removeIP(ips, ip)
if len(s.aaaaWildcards[domain]) == 0 {
delete(s.aaaaWildcards, domain)
}
}
} else {
if ips, ok := s.aaaaRecords[domain]; ok {
s.aaaaRecords[domain] = removeIP(ips, ip)
if len(s.aaaaRecords[domain]) == 0 {
delete(s.aaaaRecords, domain)
}
}
}
}
}
// GetRecords returns all IP addresses for a domain and record type
// First checks for exact matches, then checks wildcard patterns
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -109,16 +153,45 @@ func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.
var records []net.IP
switch recordType {
case RecordTypeA:
// Check exact match first
if ips, ok := s.aRecords[domain]; ok {
// Return a copy to prevent external modifications
records = make([]net.IP, len(ips))
copy(records, ips)
return records
}
// Check wildcard patterns
for pattern, ips := range s.aWildcards {
if matchWildcard(pattern, domain) {
records = append(records, ips...)
}
}
if len(records) > 0 {
// Return a copy
result := make([]net.IP, len(records))
copy(result, records)
return result
}
case RecordTypeAAAA:
// Check exact match first
if ips, ok := s.aaaaRecords[domain]; ok {
// Return a copy to prevent external modifications
records = make([]net.IP, len(ips))
copy(records, ips)
return records
}
// Check wildcard patterns
for pattern, ips := range s.aaaaWildcards {
if matchWildcard(pattern, domain) {
records = append(records, ips...)
}
}
if len(records) > 0 {
// Return a copy
result := make([]net.IP, len(records))
copy(result, records)
return result
}
}
@@ -126,6 +199,7 @@ func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.
}
// HasRecord checks if a domain has any records of the specified type
// Checks both exact matches and wildcard patterns
func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -135,11 +209,27 @@ func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
switch recordType {
case RecordTypeA:
_, ok := s.aRecords[domain]
return ok
// Check exact match
if _, ok := s.aRecords[domain]; ok {
return true
}
// Check wildcard patterns
for pattern := range s.aWildcards {
if matchWildcard(pattern, domain) {
return true
}
}
case RecordTypeAAAA:
_, ok := s.aaaaRecords[domain]
return ok
// Check exact match
if _, ok := s.aaaaRecords[domain]; ok {
return true
}
// Check wildcard patterns
for pattern := range s.aaaaWildcards {
if matchWildcard(pattern, domain) {
return true
}
}
}
return false
@@ -152,6 +242,8 @@ func (s *DNSRecordStore) Clear() {
s.aRecords = make(map[string][]net.IP)
s.aaaaRecords = make(map[string][]net.IP)
s.aWildcards = make(map[string][]net.IP)
s.aaaaWildcards = make(map[string][]net.IP)
}
// removeIP is a helper function to remove a specific IP from a slice
@@ -164,3 +256,70 @@ func removeIP(ips []net.IP, toRemove net.IP) []net.IP {
}
return result
}
// matchWildcard checks if a domain matches a wildcard pattern
// Pattern supports * (0+ chars) and ? (exactly 1 char)
// Special case: *.domain.com does not match domain.com itself
func matchWildcard(pattern, domain string) bool {
return matchWildcardInternal(pattern, domain, 0, 0)
}
// matchWildcardInternal performs the actual wildcard matching recursively
func matchWildcardInternal(pattern, domain string, pi, di int) bool {
plen := len(pattern)
dlen := len(domain)
// Base cases
if pi == plen && di == dlen {
return true
}
if pi == plen {
return false
}
// Handle wildcard characters
if pattern[pi] == '*' {
// Special case: if pattern starts with "*." and we're at the beginning,
// ensure we don't match the domain without a prefix
// e.g., *.autoco.internal should not match autoco.internal
if pi == 0 && pi+1 < plen && pattern[pi+1] == '.' {
// The * must match at least one character
if di == dlen {
return false
}
// Try matching 1 or more characters before the dot
for i := di + 1; i <= dlen; i++ {
if matchWildcardInternal(pattern, domain, pi+1, i) {
return true
}
}
return false
}
// Normal * matching (0 or more characters)
// Try matching 0 characters (skip the *)
if matchWildcardInternal(pattern, domain, pi+1, di) {
return true
}
// Try matching 1+ characters
if di < dlen {
return matchWildcardInternal(pattern, domain, pi, di+1)
}
return false
}
if pattern[pi] == '?' {
// ? matches exactly one character
if di >= dlen {
return false
}
return matchWildcardInternal(pattern, domain, pi+1, di+1)
}
// Regular character - must match exactly
if di >= dlen || pattern[pi] != domain[di] {
return false
}
return matchWildcardInternal(pattern, domain, pi+1, di+1)
}

350
dns/dns_records_test.go Normal file
View File

@@ -0,0 +1,350 @@
package dns
import (
"net"
"testing"
)
func TestWildcardMatching(t *testing.T) {
tests := []struct {
name string
pattern string
domain string
expected bool
}{
// Basic wildcard tests
{
name: "*.autoco.internal matches host.autoco.internal",
pattern: "*.autoco.internal.",
domain: "host.autoco.internal.",
expected: true,
},
{
name: "*.autoco.internal matches longerhost.autoco.internal",
pattern: "*.autoco.internal.",
domain: "longerhost.autoco.internal.",
expected: true,
},
{
name: "*.autoco.internal matches sub.host.autoco.internal",
pattern: "*.autoco.internal.",
domain: "sub.host.autoco.internal.",
expected: true,
},
{
name: "*.autoco.internal does NOT match autoco.internal",
pattern: "*.autoco.internal.",
domain: "autoco.internal.",
expected: false,
},
// Question mark wildcard tests
{
name: "host-0?.autoco.internal matches host-01.autoco.internal",
pattern: "host-0?.autoco.internal.",
domain: "host-01.autoco.internal.",
expected: true,
},
{
name: "host-0?.autoco.internal matches host-0a.autoco.internal",
pattern: "host-0?.autoco.internal.",
domain: "host-0a.autoco.internal.",
expected: true,
},
{
name: "host-0?.autoco.internal does NOT match host-0.autoco.internal",
pattern: "host-0?.autoco.internal.",
domain: "host-0.autoco.internal.",
expected: false,
},
{
name: "host-0?.autoco.internal does NOT match host-012.autoco.internal",
pattern: "host-0?.autoco.internal.",
domain: "host-012.autoco.internal.",
expected: false,
},
// Combined wildcard tests
{
name: "*.host-0?.autoco.internal matches sub.host-01.autoco.internal",
pattern: "*.host-0?.autoco.internal.",
domain: "sub.host-01.autoco.internal.",
expected: true,
},
{
name: "*.host-0?.autoco.internal matches prefix.host-0a.autoco.internal",
pattern: "*.host-0?.autoco.internal.",
domain: "prefix.host-0a.autoco.internal.",
expected: true,
},
{
name: "*.host-0?.autoco.internal does NOT match host-01.autoco.internal",
pattern: "*.host-0?.autoco.internal.",
domain: "host-01.autoco.internal.",
expected: false,
},
// Multiple asterisks
{
name: "*.*. autoco.internal matches any.thing.autoco.internal",
pattern: "*.*.autoco.internal.",
domain: "any.thing.autoco.internal.",
expected: true,
},
{
name: "*.*.autoco.internal does NOT match single.autoco.internal",
pattern: "*.*.autoco.internal.",
domain: "single.autoco.internal.",
expected: false,
},
// Asterisk in middle
{
name: "host-*.autoco.internal matches host-anything.autoco.internal",
pattern: "host-*.autoco.internal.",
domain: "host-anything.autoco.internal.",
expected: true,
},
{
name: "host-*.autoco.internal matches host-.autoco.internal (empty match)",
pattern: "host-*.autoco.internal.",
domain: "host-.autoco.internal.",
expected: true,
},
// Multiple question marks
{
name: "host-??.autoco.internal matches host-01.autoco.internal",
pattern: "host-??.autoco.internal.",
domain: "host-01.autoco.internal.",
expected: true,
},
{
name: "host-??.autoco.internal does NOT match host-1.autoco.internal",
pattern: "host-??.autoco.internal.",
domain: "host-1.autoco.internal.",
expected: false,
},
// Exact match (no wildcards)
{
name: "exact.autoco.internal matches exact.autoco.internal",
pattern: "exact.autoco.internal.",
domain: "exact.autoco.internal.",
expected: true,
},
{
name: "exact.autoco.internal does NOT match other.autoco.internal",
pattern: "exact.autoco.internal.",
domain: "other.autoco.internal.",
expected: false,
},
// Edge cases
{
name: "* matches anything",
pattern: "*",
domain: "anything.at.all.",
expected: true,
},
{
name: "*.* matches multi.level.",
pattern: "*.*",
domain: "multi.level.",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := matchWildcard(tt.pattern, tt.domain)
if result != tt.expected {
t.Errorf("matchWildcard(%q, %q) = %v, want %v", tt.pattern, tt.domain, result, tt.expected)
}
})
}
}
func TestDNSRecordStoreWildcard(t *testing.T) {
store := NewDNSRecordStore()
// Add wildcard records
wildcardIP := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", wildcardIP)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
// Add exact record
exactIP := net.ParseIP("10.0.0.2")
err = store.AddRecord("exact.autoco.internal", exactIP)
if err != nil {
t.Fatalf("Failed to add exact record: %v", err)
}
// Test exact match takes precedence
ips := store.GetRecords("exact.autoco.internal.", RecordTypeA)
if len(ips) != 1 {
t.Errorf("Expected 1 IP for exact match, got %d", len(ips))
}
if !ips[0].Equal(exactIP) {
t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0])
}
// Test wildcard match
ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
if len(ips) != 1 {
t.Errorf("Expected 1 IP for wildcard match, got %d", len(ips))
}
if !ips[0].Equal(wildcardIP) {
t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0])
}
// Test non-match (base domain)
ips = store.GetRecords("autoco.internal.", RecordTypeA)
if len(ips) != 0 {
t.Errorf("Expected 0 IPs for base domain, got %d", len(ips))
}
}
func TestDNSRecordStoreComplexWildcard(t *testing.T) {
store := NewDNSRecordStore()
// Add complex wildcard pattern
ip1 := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.host-0?.autoco.internal", ip1)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
// Test matching domain
ips := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA)
if len(ips) != 1 {
t.Errorf("Expected 1 IP for complex wildcard match, got %d", len(ips))
}
if len(ips) > 0 && !ips[0].Equal(ip1) {
t.Errorf("Expected IP %v, got %v", ip1, ips[0])
}
// Test non-matching domain (missing prefix)
ips = store.GetRecords("host-01.autoco.internal.", RecordTypeA)
if len(ips) != 0 {
t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips))
}
// Test non-matching domain (wrong ? position)
ips = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA)
if len(ips) != 0 {
t.Errorf("Expected 0 IPs for domain with wrong ? match, got %d", len(ips))
}
}
func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
store := NewDNSRecordStore()
// Add wildcard record
ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", ip)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
// Verify it exists
ips := store.GetRecords("host.autoco.internal.", RecordTypeA)
if len(ips) != 1 {
t.Errorf("Expected 1 IP before removal, got %d", len(ips))
}
// Remove wildcard record
store.RemoveRecord("*.autoco.internal", nil)
// Verify it's gone
ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
if len(ips) != 0 {
t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
}
}
func TestDNSRecordStoreMultipleWildcards(t *testing.T) {
store := NewDNSRecordStore()
// Add multiple wildcard patterns that don't overlap
ip1 := net.ParseIP("10.0.0.1")
ip2 := net.ParseIP("10.0.0.2")
ip3 := net.ParseIP("10.0.0.3")
err := store.AddRecord("*.prod.autoco.internal", ip1)
if err != nil {
t.Fatalf("Failed to add first wildcard: %v", err)
}
err = store.AddRecord("*.dev.autoco.internal", ip2)
if err != nil {
t.Fatalf("Failed to add second wildcard: %v", err)
}
// Add a broader wildcard that matches both
err = store.AddRecord("*.autoco.internal", ip3)
if err != nil {
t.Fatalf("Failed to add third wildcard: %v", err)
}
// Test domain matching only the prod pattern and the broad pattern
ips := store.GetRecords("host.prod.autoco.internal.", RecordTypeA)
if len(ips) != 2 {
t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips))
}
// Test domain matching only the dev pattern and the broad pattern
ips = store.GetRecords("service.dev.autoco.internal.", RecordTypeA)
if len(ips) != 2 {
t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips))
}
// Test domain matching only the broad pattern
ips = store.GetRecords("host.test.autoco.internal.", RecordTypeA)
if len(ips) != 1 {
t.Errorf("Expected 1 IP (broad only), got %d", len(ips))
}
}
func TestDNSRecordStoreIPv6Wildcard(t *testing.T) {
store := NewDNSRecordStore()
// Add IPv6 wildcard record
ip := net.ParseIP("2001:db8::1")
err := store.AddRecord("*.autoco.internal", ip)
if err != nil {
t.Fatalf("Failed to add IPv6 wildcard record: %v", err)
}
// Test wildcard match for IPv6
ips := store.GetRecords("host.autoco.internal.", RecordTypeAAAA)
if len(ips) != 1 {
t.Errorf("Expected 1 IPv6 for wildcard match, got %d", len(ips))
}
if len(ips) > 0 && !ips[0].Equal(ip) {
t.Errorf("Expected IPv6 %v, got %v", ip, ips[0])
}
}
func TestHasRecordWildcard(t *testing.T) {
store := NewDNSRecordStore()
// Add wildcard record
ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", ip)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
// Test HasRecord with wildcard match
if !store.HasRecord("host.autoco.internal.", RecordTypeA) {
t.Error("Expected HasRecord to return true for wildcard match")
}
// Test HasRecord with non-match
if store.HasRecord("autoco.internal.", RecordTypeA) {
t.Error("Expected HasRecord to return false for base domain")
}
}

View File

@@ -5,9 +5,13 @@ package dns
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"net/netip"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
@@ -28,19 +32,38 @@ const (
keyServerPort = "ServerPort"
arraySymbol = "* "
digitSymbol = "# "
// State file name for crash recovery
dnsStateFileName = "dns_state.json"
)
// DNSPersistentState represents the state saved to disk for crash recovery
type DNSPersistentState struct {
CreatedKeys []string `json:"created_keys"`
}
// DarwinDNSConfigurator manages DNS settings on macOS using scutil
type DarwinDNSConfigurator struct {
createdKeys map[string]struct{}
originalState *DNSState
stateFilePath string
}
// NewDarwinDNSConfigurator creates a new macOS DNS configurator
func NewDarwinDNSConfigurator() (*DarwinDNSConfigurator, error) {
return &DarwinDNSConfigurator{
createdKeys: make(map[string]struct{}),
}, nil
stateFilePath := getDNSStateFilePath()
configurator := &DarwinDNSConfigurator{
createdKeys: make(map[string]struct{}),
stateFilePath: stateFilePath,
}
// Clean up any leftover state from a previous crash
if err := configurator.CleanupUncleanShutdown(); err != nil {
logger.Warn("Failed to cleanup previous DNS state: %v", err)
}
return configurator, nil
}
// Name returns the configurator name
@@ -67,6 +90,11 @@ func (d *DarwinDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, erro
return nil, fmt.Errorf("apply DNS servers: %w", err)
}
// Persist state to disk for crash recovery
if err := d.saveState(); err != nil {
logger.Warn("Failed to save DNS state for crash recovery: %v", err)
}
// Flush DNS cache
if err := d.flushDNSCache(); err != nil {
// Non-fatal, just log
@@ -85,6 +113,11 @@ func (d *DarwinDNSConfigurator) RestoreDNS() error {
}
}
// Clear state file after successful restoration
if err := d.clearState(); err != nil {
logger.Warn("Failed to clear DNS state file: %v", err)
}
// Flush DNS cache
if err := d.flushDNSCache(); err != nil {
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
@@ -112,6 +145,47 @@ func (d *DarwinDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
return servers, nil
}
// CleanupUncleanShutdown removes any DNS keys left over from a previous crash
func (d *DarwinDNSConfigurator) CleanupUncleanShutdown() error {
state, err := d.loadState()
if err != nil {
if os.IsNotExist(err) {
// No state file, nothing to clean up
return nil
}
return fmt.Errorf("load state: %w", err)
}
if len(state.CreatedKeys) == 0 {
// No keys to clean up
return nil
}
logger.Info("Found DNS state from previous session, cleaning up %d keys", len(state.CreatedKeys))
// Remove all keys from previous session
var lastErr error
for _, key := range state.CreatedKeys {
logger.Debug("Removing leftover DNS key: %s", key)
if err := d.removeKeyDirect(key); err != nil {
logger.Warn("Failed to remove DNS key %s: %v", key, err)
lastErr = err
}
}
// Clear state file
if err := d.clearState(); err != nil {
logger.Warn("Failed to clear DNS state file: %v", err)
}
// Flush DNS cache after cleanup
if err := d.flushDNSCache(); err != nil {
logger.Warn("Failed to flush DNS cache after cleanup: %v", err)
}
return lastErr
}
// applyDNSServers applies the DNS server configuration
func (d *DarwinDNSConfigurator) applyDNSServers(servers []netip.Addr) error {
if len(servers) == 0 {
@@ -156,15 +230,25 @@ func (d *DarwinDNSConfigurator) addDNSState(state, domains string, dnsServer net
return nil
}
// removeKey removes a DNS configuration key
// removeKey removes a DNS configuration key and updates internal state
func (d *DarwinDNSConfigurator) removeKey(key string) error {
if err := d.removeKeyDirect(key); err != nil {
return err
}
delete(d.createdKeys, key)
return nil
}
// removeKeyDirect removes a DNS configuration key without updating internal state
// Used for cleanup operations
func (d *DarwinDNSConfigurator) removeKeyDirect(key string) error {
cmd := fmt.Sprintf("remove %s\n", key)
if _, err := d.runScutil(cmd); err != nil {
return fmt.Errorf("remove key: %w", err)
}
delete(d.createdKeys, key)
return nil
}
@@ -266,3 +350,70 @@ func (d *DarwinDNSConfigurator) runScutil(commands string) ([]byte, error) {
return output, nil
}
// getDNSStateFilePath returns the path to the DNS state file
func getDNSStateFilePath() string {
var stateDir string
switch runtime.GOOS {
case "darwin":
stateDir = filepath.Join(os.Getenv("HOME"), "Library", "Application Support", "olm-client")
default:
stateDir = filepath.Join(os.Getenv("HOME"), ".config", "olm-client")
}
if err := os.MkdirAll(stateDir, 0755); err != nil {
logger.Warn("Failed to create state directory: %v", err)
}
return filepath.Join(stateDir, dnsStateFileName)
}
// saveState persists the current DNS state to disk
func (d *DarwinDNSConfigurator) saveState() error {
keys := make([]string, 0, len(d.createdKeys))
for key := range d.createdKeys {
keys = append(keys, key)
}
state := DNSPersistentState{
CreatedKeys: keys,
}
data, err := json.MarshalIndent(state, "", " ")
if err != nil {
return fmt.Errorf("marshal state: %w", err)
}
if err := os.WriteFile(d.stateFilePath, data, 0644); err != nil {
return fmt.Errorf("write state file: %w", err)
}
logger.Debug("Saved DNS state to %s", d.stateFilePath)
return nil
}
// loadState loads the DNS state from disk
func (d *DarwinDNSConfigurator) loadState() (*DNSPersistentState, error) {
data, err := os.ReadFile(d.stateFilePath)
if err != nil {
return nil, err
}
var state DNSPersistentState
if err := json.Unmarshal(data, &state); err != nil {
return nil, fmt.Errorf("unmarshal state: %w", err)
}
return &state, nil
}
// clearState removes the DNS state file
func (d *DarwinDNSConfigurator) clearState() error {
err := os.Remove(d.stateFilePath)
if err != nil && !os.IsNotExist(err) {
return fmt.Errorf("remove state file: %w", err)
}
logger.Debug("Cleared DNS state file")
return nil
}

View File

@@ -22,7 +22,11 @@ type FileDNSConfigurator struct {
// NewFileDNSConfigurator creates a new file-based DNS configurator
func NewFileDNSConfigurator() (*FileDNSConfigurator, error) {
return &FileDNSConfigurator{}, nil
f := &FileDNSConfigurator{}
if err := f.CleanupUncleanShutdown(); err != nil {
return nil, fmt.Errorf("cleanup unclean shutdown: %w", err)
}
return f, nil
}
// Name returns the configurator name
@@ -78,6 +82,30 @@ func (f *FileDNSConfigurator) RestoreDNS() error {
return nil
}
// CleanupUncleanShutdown removes any DNS configuration left over from a previous crash
// For the file-based configurator, we check if a backup file exists (indicating a crash
// happened while DNS was configured) and restore from it if so.
func (f *FileDNSConfigurator) CleanupUncleanShutdown() error {
// Check if backup file exists from a previous session
if !f.isBackupExists() {
// No backup file, nothing to clean up
return nil
}
// A backup exists, which means we crashed while DNS was configured
// Restore the original resolv.conf
if err := copyFile(resolvConfBackupPath, resolvConfPath); err != nil {
return fmt.Errorf("restore from backup during cleanup: %w", err)
}
// Remove backup file
if err := os.Remove(resolvConfBackupPath); err != nil {
return fmt.Errorf("remove backup file during cleanup: %w", err)
}
return nil
}
// GetCurrentDNS returns the currently configured DNS servers
func (f *FileDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
content, err := os.ReadFile(resolvConfPath)

View File

@@ -50,11 +50,18 @@ func NewNetworkManagerDNSConfigurator(ifaceName string) (*NetworkManagerDNSConfi
return nil, fmt.Errorf("NetworkManager conf.d directory not found: %s", networkManagerConfDir)
}
return &NetworkManagerDNSConfigurator{
configurator := &NetworkManagerDNSConfigurator{
ifaceName: ifaceName,
confPath: networkManagerConfDir + "/" + networkManagerDNSConfFile,
dispatchPath: networkManagerDispatcherDir + "/" + networkManagerDispatcherFile,
}, nil
}
// Clean up any stale configuration from a previous unclean shutdown
if err := configurator.CleanupUncleanShutdown(); err != nil {
return nil, fmt.Errorf("cleanup unclean shutdown: %w", err)
}
return configurator, nil
}
// Name returns the configurator name
@@ -100,6 +107,30 @@ func (n *NetworkManagerDNSConfigurator) RestoreDNS() error {
return nil
}
// CleanupUncleanShutdown removes any DNS configuration left over from a previous crash
// For NetworkManager, we check if our config file exists and remove it if so.
// This ensures that if the process crashed while DNS was configured, the stale
// configuration is removed on the next startup.
func (n *NetworkManagerDNSConfigurator) CleanupUncleanShutdown() error {
// Check if our config file exists from a previous session
if _, err := os.Stat(n.confPath); os.IsNotExist(err) {
// No config file, nothing to clean up
return nil
}
// Remove the stale configuration file
if err := os.Remove(n.confPath); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("remove stale DNS config file: %w", err)
}
// Reload NetworkManager to apply the change
if err := n.reloadNetworkManager(); err != nil {
return fmt.Errorf("reload NetworkManager after cleanup: %w", err)
}
return nil
}
// GetCurrentDNS returns the currently configured DNS servers by reading /etc/resolv.conf
func (n *NetworkManagerDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
content, err := os.ReadFile("/etc/resolv.conf")

View File

@@ -31,10 +31,17 @@ func NewResolvconfDNSConfigurator(ifaceName string) (*ResolvconfDNSConfigurator,
return nil, fmt.Errorf("detect resolvconf type: %w", err)
}
return &ResolvconfDNSConfigurator{
configurator := &ResolvconfDNSConfigurator{
ifaceName: ifaceName,
implType: implType,
}, nil
}
// Call cleanup function to remove any stale DNS config for this interface
if err := configurator.CleanupUncleanShutdown(); err != nil {
return nil, fmt.Errorf("cleanup unclean shutdown: %w", err)
}
return configurator, nil
}
// Name returns the configurator name
@@ -84,6 +91,28 @@ func (r *ResolvconfDNSConfigurator) RestoreDNS() error {
return nil
}
// CleanupUncleanShutdown removes any DNS configuration left over from a previous crash
// For resolvconf, we attempt to delete any entry for the interface name.
// This ensures that if the process crashed while DNS was configured, the stale
// entry is removed on the next startup.
func (r *ResolvconfDNSConfigurator) CleanupUncleanShutdown() error {
// Try to delete any existing entry for this interface
// This is idempotent - if no entry exists, resolvconf will just return success
var cmd *exec.Cmd
switch r.implType {
case "openresolv":
cmd = exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
default:
cmd = exec.Command(resolvconfCommand, "-d", r.ifaceName)
}
// Ignore errors - the entry may not exist, which is fine
_ = cmd.Run()
return nil
}
// GetCurrentDNS returns the currently configured DNS servers
func (r *ResolvconfDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
// resolvconf doesn't provide a direct way to query per-interface DNS

View File

@@ -73,10 +73,17 @@ func NewSystemdResolvedDNSConfigurator(ifaceName string) (*SystemdResolvedDNSCon
return nil, fmt.Errorf("get link: %w", err)
}
return &SystemdResolvedDNSConfigurator{
config := &SystemdResolvedDNSConfigurator{
ifaceName: ifaceName,
dbusLinkObject: dbus.ObjectPath(linkPath),
}, nil
}
// Call cleanup function here
if err := config.CleanupUncleanShutdown(); err != nil {
fmt.Printf("warning: cleanup unclean shutdown failed: %v\n", err)
}
return config, nil
}
// Name returns the configurator name
@@ -133,6 +140,17 @@ func (s *SystemdResolvedDNSConfigurator) RestoreDNS() error {
return nil
}
// CleanupUncleanShutdown removes any DNS configuration left over from a previous crash
// For systemd-resolved, the DNS configuration is tied to the network interface.
// When the interface is destroyed and recreated, systemd-resolved automatically
// clears the per-link DNS settings, so there's nothing to clean up.
func (s *SystemdResolvedDNSConfigurator) CleanupUncleanShutdown() error {
// systemd-resolved DNS configuration is per-link and automatically cleared
// when the link (interface) is destroyed. Since the WireGuard interface is
// recreated on restart, there's no leftover state to clean up.
return nil
}
// GetCurrentDNS returns the currently configured DNS servers
// Note: systemd-resolved doesn't easily expose current per-link DNS servers via D-Bus
// This is a placeholder that returns an empty list

View File

@@ -17,6 +17,10 @@ type DNSConfigurator interface {
// Name returns the name of this configurator implementation
Name() string
// CleanupUncleanShutdown removes any DNS configuration left over from
// a previous crash or unclean shutdown. This should be called on startup.
CleanupUncleanShutdown() error
}
// DNSConfig contains the configuration for DNS override

View File

@@ -113,6 +113,18 @@ func (w *WindowsDNSConfigurator) RestoreDNS() error {
return nil
}
// CleanupUncleanShutdown removes any DNS configuration left over from a previous crash
// On Windows, we rely on the registry-based approach which doesn't leave orphaned state
// in the same way as macOS scutil. The DNS settings are tied to the interface which
// gets recreated on restart.
func (w *WindowsDNSConfigurator) CleanupUncleanShutdown() error {
// Windows DNS configuration via registry is interface-specific.
// When the WireGuard interface is recreated, it gets a new GUID,
// so there's no leftover state to clean up from previous sessions.
// The old interface's registry keys are effectively orphaned but harmless.
return nil
}
// GetCurrentDNS returns the currently configured DNS servers
func (w *WindowsDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
regKey, err := w.getInterfaceRegistryKey(registry.QUERY_VALUE)

2
go.mod
View File

@@ -4,7 +4,7 @@ go 1.25
require (
github.com/Microsoft/go-winio v0.6.2
github.com/fosrl/newt v0.0.0-20251208171729-6d7985689552
github.com/fosrl/newt v0.0.0-20251222020104-a21a8e90fa01
github.com/godbus/dbus/v5 v5.2.0
github.com/gorilla/websocket v1.5.3
github.com/miekg/dns v1.1.68

4
go.sum
View File

@@ -1,7 +1,7 @@
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/fosrl/newt v0.0.0-20251208171729-6d7985689552 h1:51pHUtoqQhYPS9OiBDHLgYV44X/CBzR5J7GuWO3izhU=
github.com/fosrl/newt v0.0.0-20251208171729-6d7985689552/go.mod h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI=
github.com/fosrl/newt v0.0.0-20251222020104-a21a8e90fa01 h1:VpuI42l4enih//6IFFQDln/B7WukfMePxIRIpXsNe/0=
github.com/fosrl/newt v0.0.0-20251222020104-a21a8e90fa01/go.mod h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI=
github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8=
github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=

BIN
http.pcap

Binary file not shown.

View File

@@ -374,8 +374,14 @@ func StartTunnel(config TunnelConfig) {
logger.Error("Failed to bring up WireGuard device: %v", err)
}
// Extract interface IP (strip CIDR notation if present)
interfaceIP := wgData.TunnelIP
if strings.Contains(interfaceIP, "/") {
interfaceIP = strings.Split(interfaceIP, "/")[0]
}
// Create and start DNS proxy
dnsProxy, err = dns.NewDNSProxy(tdev, middleDev, config.MTU, wgData.UtilitySubnet, config.UpstreamDNS)
dnsProxy, err = dns.NewDNSProxy(tdev, middleDev, config.MTU, wgData.UtilitySubnet, config.UpstreamDNS, config.TunnelDNS, interfaceIP)
if err != nil {
logger.Error("Failed to create DNS proxy: %v", err)
}
@@ -388,12 +394,6 @@ func StartTunnel(config TunnelConfig) {
logger.Error("Failed to add route for utility subnet: %v", err)
}
// TODO: seperate adding the callback to this so we can init it above with the interface
interfaceIP := wgData.TunnelIP
if strings.Contains(interfaceIP, "/") {
interfaceIP = strings.Split(interfaceIP, "/")[0]
}
// Create peer manager with integrated peer monitoring
peerManager = peers.NewPeerManager(peers.PeerManagerConfig{
Device: dev,
@@ -566,6 +566,14 @@ func StartTunnel(config TunnelConfig) {
return
}
// Remove any exit nodes associated with this peer from hole punching
if holePunchManager != nil {
removed := holePunchManager.RemoveExitNodesByPeer(removeData.SiteId)
if removed > 0 {
logger.Info("Removed %d exit nodes associated with peer %d from hole punch rotation", removed, removeData.SiteId)
}
}
// Remove successful
logger.Info("Successfully removed peer for site %d", removeData.SiteId)
})
@@ -727,7 +735,7 @@ func StartTunnel(config TunnelConfig) {
// Update HTTP server to mark this peer as using relay
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.RelayEndpoint, true)
peerManager.RelayPeer(relayData.SiteId, primaryRelay)
peerManager.RelayPeer(relayData.SiteId, primaryRelay, relayData.RelayPort)
})
olm.RegisterHandler("olm/wg/peer/unrelay", func(msg websocket.WSMessage) {
@@ -777,6 +785,7 @@ func StartTunnel(config TunnelConfig) {
ExitNode struct {
PublicKey string `json:"publicKey"`
Endpoint string `json:"endpoint"`
RelayPort uint16 `json:"relayPort"`
} `json:"exitNode"`
}
@@ -792,9 +801,17 @@ func StartTunnel(config TunnelConfig) {
return
}
relayPort := handshakeData.ExitNode.RelayPort
if relayPort == 0 {
relayPort = 21820 // default relay port
}
siteId := handshakeData.SiteId
exitNode := holepunch.ExitNode{
Endpoint: handshakeData.ExitNode.Endpoint,
RelayPort: relayPort,
PublicKey: handshakeData.ExitNode.PublicKey,
SiteIds: []int{siteId},
}
added := holePunchManager.AddExitNode(exitNode)
@@ -878,9 +895,16 @@ func StartTunnel(config TunnelConfig) {
// Convert websocket.ExitNode to holepunch.ExitNode
hpExitNodes := make([]holepunch.ExitNode, len(exitNodes))
for i, node := range exitNodes {
relayPort := node.RelayPort
if relayPort == 0 {
relayPort = 21820 // default relay port
}
hpExitNodes[i] = holepunch.ExitNode{
Endpoint: node.Endpoint,
RelayPort: relayPort,
PublicKey: node.PublicKey,
SiteIds: node.SiteIds,
}
}

View File

@@ -61,6 +61,7 @@ type TunnelConfig struct {
EnableUAPI bool
OverrideDNS bool
TunnelDNS bool
DisableRelay bool
}

View File

@@ -743,7 +743,7 @@ func (pm *PeerManager) RemoveAlias(siteId int, aliasName string) error {
}
// RelayPeer handles failover to the relay server when a peer is disconnected
func (pm *PeerManager) RelayPeer(siteId int, relayEndpoint string) {
func (pm *PeerManager) RelayPeer(siteId int, relayEndpoint string, relayPort uint16) {
pm.mu.Lock()
peer, exists := pm.peers[siteId]
if exists {
@@ -764,10 +764,14 @@ func (pm *PeerManager) RelayPeer(siteId int, relayEndpoint string) {
formattedEndpoint = fmt.Sprintf("[%s]", relayEndpoint)
}
if relayPort == 0 {
relayPort = 21820 // fall back to 21820 for backward compatibility
}
// Update only the endpoint for this peer (update_only preserves other settings)
wgConfig := fmt.Sprintf(`public_key=%s
update_only=true
endpoint=%s:21820`, util.FixKey(peer.PublicKey), formattedEndpoint)
endpoint=%s:%d`, util.FixKey(peer.PublicKey), formattedEndpoint, relayPort)
err := pm.device.IpcSet(wgConfig)
if err != nil {

View File

@@ -505,7 +505,7 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
pm.mutex.Unlock()
for siteID, endpoint := range endpoints {
logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint)
// logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint)
result := pm.holepunchTester.TestEndpoint(endpoint, timeout)
pm.mutex.Lock()

View File

@@ -33,6 +33,7 @@ type PeerRemove struct {
type RelayPeerData struct {
SiteId int `json:"siteId"`
RelayEndpoint string `json:"relayEndpoint"`
RelayPort uint16 `json:"relayPort"`
}
type UnRelayPeerData struct {

View File

@@ -48,7 +48,9 @@ type TokenResponse struct {
type ExitNode struct {
Endpoint string `json:"endpoint"`
RelayPort uint16 `json:"relayPort"`
PublicKey string `json:"publicKey"`
SiteIds []int `json:"siteIds"`
}
type WSMessage struct {