From 6c7ee31330d50c0424dc5f2dd15319d27ce011e0 Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 23 Nov 2025 15:57:35 -0500 Subject: [PATCH] Working on sending down the dns Former-commit-id: 1a8385c45790a5924519025a83081dd1a4da4939 --- api/api.go | 26 ++++++++++--------- config.go | 65 +++++++++++++++++++++++++++++++++++++++++++++--- dns/dns_proxy.go | 56 ++++++++++++++++++++++------------------- main.go | 2 ++ olm/olm.go | 54 ++++++++++++++++++++-------------------- olm/types.go | 28 ++++++--------------- 6 files changed, 143 insertions(+), 88 deletions(-) diff --git a/api/api.go b/api/api.go index b8c848e..cf04a89 100644 --- a/api/api.go +++ b/api/api.go @@ -13,18 +13,20 @@ import ( // ConnectionRequest defines the structure for an incoming connection request type ConnectionRequest struct { - ID string `json:"id"` - Secret string `json:"secret"` - Endpoint string `json:"endpoint"` - UserToken string `json:"userToken,omitempty"` - MTU int `json:"mtu,omitempty"` - DNS string `json:"dns,omitempty"` - InterfaceName string `json:"interfaceName,omitempty"` - Holepunch bool `json:"holepunch,omitempty"` - TlsClientCert string `json:"tlsClientCert,omitempty"` - PingInterval string `json:"pingInterval,omitempty"` - PingTimeout string `json:"pingTimeout,omitempty"` - OrgID string `json:"orgId,omitempty"` + ID string `json:"id"` + Secret string `json:"secret"` + Endpoint string `json:"endpoint"` + UserToken string `json:"userToken,omitempty"` + MTU int `json:"mtu,omitempty"` + DNS string `json:"dns,omitempty"` + DNSProxyIP string `json:"dnsProxyIP,omitempty"` + UpstreamDNS []string `json:"upstreamDNS,omitempty"` + InterfaceName string `json:"interfaceName,omitempty"` + Holepunch bool `json:"holepunch,omitempty"` + TlsClientCert string `json:"tlsClientCert,omitempty"` + PingInterval string `json:"pingInterval,omitempty"` + PingTimeout string `json:"pingTimeout,omitempty"` + OrgID string `json:"orgId,omitempty"` } // SwitchOrgRequest defines the structure for switching organizations diff --git a/config.go b/config.go index e7b8c2f..707b3ec 100644 --- a/config.go +++ b/config.go @@ -8,6 +8,7 @@ import ( "path/filepath" "runtime" "strconv" + "strings" "time" ) @@ -21,9 +22,11 @@ type OlmConfig struct { UserToken string `json:"userToken"` // Network settings - MTU int `json:"mtu"` - DNS string `json:"dns"` - InterfaceName string `json:"interface"` + MTU int `json:"mtu"` + DNS string `json:"dns"` + DNSProxyIP string `json:"dnsProxyIP"` + UpstreamDNS []string `json:"upstreamDNS"` + InterfaceName string `json:"interface"` // Logging LogLevel string `json:"logLevel"` @@ -76,6 +79,8 @@ func DefaultConfig() *OlmConfig { config := &OlmConfig{ MTU: 1280, DNS: "8.8.8.8", + DNSProxyIP: "", + UpstreamDNS: []string{"8.8.8.8"}, LogLevel: "INFO", InterfaceName: "olm", EnableAPI: false, @@ -90,6 +95,8 @@ func DefaultConfig() *OlmConfig { // Track default sources config.sources["mtu"] = string(SourceDefault) config.sources["dns"] = string(SourceDefault) + config.sources["dnsProxyIP"] = string(SourceDefault) + config.sources["upstreamDNS"] = string(SourceDefault) config.sources["logLevel"] = string(SourceDefault) config.sources["interface"] = string(SourceDefault) config.sources["enableApi"] = string(SourceDefault) @@ -213,6 +220,14 @@ func loadConfigFromEnv(config *OlmConfig) { config.DNS = val config.sources["dns"] = string(SourceEnv) } + if val := os.Getenv("DNS_PROXY_IP"); val != "" { + config.DNSProxyIP = val + config.sources["dnsProxyIP"] = string(SourceEnv) + } + if val := os.Getenv("UPSTREAM_DNS"); val != "" { + config.UpstreamDNS = []string{val} + config.sources["upstreamDNS"] = string(SourceEnv) + } if val := os.Getenv("LOG_LEVEL"); val != "" { config.LogLevel = val config.sources["logLevel"] = string(SourceEnv) @@ -264,6 +279,8 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) { "userToken": config.UserToken, "mtu": config.MTU, "dns": config.DNS, + "dnsProxyIP": config.DNSProxyIP, + "upstreamDNS": fmt.Sprintf("%v", config.UpstreamDNS), "logLevel": config.LogLevel, "interface": config.InterfaceName, "httpAddr": config.HTTPAddr, @@ -283,6 +300,9 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) { serviceFlags.StringVar(&config.UserToken, "user-token", config.UserToken, "User token (optional)") serviceFlags.IntVar(&config.MTU, "mtu", config.MTU, "MTU to use") serviceFlags.StringVar(&config.DNS, "dns", config.DNS, "DNS server to use") + serviceFlags.StringVar(&config.DNSProxyIP, "dns-proxy-ip", config.DNSProxyIP, "IP address for the DNS proxy (required for DNS proxy)") + var upstreamDNSFlag string + serviceFlags.StringVar(&upstreamDNSFlag, "upstream-dns", "", "Upstream DNS server(s) (comma-separated, default: 8.8.8.8)") serviceFlags.StringVar(&config.LogLevel, "log-level", config.LogLevel, "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") serviceFlags.StringVar(&config.InterfaceName, "interface", config.InterfaceName, "Name of the WireGuard interface") serviceFlags.StringVar(&config.HTTPAddr, "http-addr", config.HTTPAddr, "HTTP server address (e.g., ':9452')") @@ -301,6 +321,16 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) { return false, false, err } + // Parse upstream DNS flag if provided + if upstreamDNSFlag != "" { + config.UpstreamDNS = []string{} + for _, dns := range splitComma(upstreamDNSFlag) { + if dns != "" { + config.UpstreamDNS = append(config.UpstreamDNS, dns) + } + } + } + // Track which values were changed by CLI args if config.Endpoint != origValues["endpoint"].(string) { config.sources["endpoint"] = string(SourceCLI) @@ -323,6 +353,12 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) { if config.DNS != origValues["dns"].(string) { config.sources["dns"] = string(SourceCLI) } + if config.DNSProxyIP != origValues["dnsProxyIP"].(string) { + config.sources["dnsProxyIP"] = string(SourceCLI) + } + if fmt.Sprintf("%v", config.UpstreamDNS) != origValues["upstreamDNS"].(string) { + config.sources["upstreamDNS"] = string(SourceCLI) + } if config.LogLevel != origValues["logLevel"].(string) { config.sources["logLevel"] = string(SourceCLI) } @@ -418,6 +454,14 @@ func mergeConfigs(dest, src *OlmConfig) { dest.DNS = src.DNS dest.sources["dns"] = string(SourceFile) } + if src.DNSProxyIP != "" { + dest.DNSProxyIP = src.DNSProxyIP + dest.sources["dnsProxyIP"] = string(SourceFile) + } + if len(src.UpstreamDNS) > 0 && fmt.Sprintf("%v", src.UpstreamDNS) != "[8.8.8.8]" { + dest.UpstreamDNS = src.UpstreamDNS + dest.sources["upstreamDNS"] = string(SourceFile) + } if src.LogLevel != "" && src.LogLevel != "INFO" { dest.LogLevel = src.LogLevel dest.sources["logLevel"] = string(SourceFile) @@ -526,6 +570,8 @@ func (c *OlmConfig) ShowConfig() { fmt.Println("\nNetwork:") fmt.Printf(" mtu = %d [%s]\n", c.MTU, getSource("mtu")) fmt.Printf(" dns = %s [%s]\n", c.DNS, getSource("dns")) + fmt.Printf(" dns-proxy-ip = %s [%s]\n", formatValue("dnsProxyIP", c.DNSProxyIP), getSource("dnsProxyIP")) + fmt.Printf(" upstream-dns = %v [%s]\n", c.UpstreamDNS, getSource("upstreamDNS")) fmt.Printf(" interface = %s [%s]\n", c.InterfaceName, getSource("interface")) // Logging @@ -560,3 +606,16 @@ func (c *OlmConfig) ShowConfig() { fmt.Println("\nPriority: cli > environment > file > default") fmt.Println() } + +// splitComma splits a comma-separated string into a slice of trimmed strings +func splitComma(s string) []string { + parts := strings.Split(s, ",") + result := make([]string, 0, len(parts)) + for _, part := range parts { + trimmed := strings.TrimSpace(part) + if trimmed != "" { + result = append(result, trimmed) + } + } + return result +} diff --git a/dns/dns_proxy.go b/dns/dns_proxy.go index 4734b2c..3103c56 100644 --- a/dns/dns_proxy.go +++ b/dns/dns_proxy.go @@ -25,23 +25,19 @@ import ( ) const ( - // DNS proxy listening address - DNSProxyIP = "10.30.30.30" - DNSPort = 53 - - // Upstream DNS servers - UpstreamDNS1 = "8.8.8.8:53" - UpstreamDNS2 = "8.8.4.4:53" + DNSPort = 53 ) // DNSProxy implements a DNS proxy using gvisor netstack type DNSProxy struct { - stack *stack.Stack - ep *channel.Endpoint - proxyIP netip.Addr - mtu int - tunDevice tun.Device // Direct reference to underlying TUN device for responses - recordStore *DNSRecordStore // Local DNS records + stack *stack.Stack + ep *channel.Endpoint + proxyIP netip.Addr + upstreamDNS []string + mtu int + tunDevice tun.Device // Direct reference to underlying TUN device for responses + middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering + recordStore *DNSRecordStore // Local DNS records ctx context.Context cancel context.CancelFunc @@ -49,12 +45,16 @@ type DNSProxy struct { } // NewDNSProxy creates a new DNS proxy -func NewDNSProxy(tunDevice tun.Device, mtu int) (*DNSProxy, error) { - proxyIP, err := netip.ParseAddr(DNSProxyIP) +func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu int, dnsProxyIP string, upstreamDns []string) (*DNSProxy, error) { + proxyIP, err := netip.ParseAddr(dnsProxyIP) if err != nil { return nil, fmt.Errorf("invalid proxy IP: %w", err) } + if len(upstreamDns) == 0 { + return nil, fmt.Errorf("at least one upstream DNS server must be specified") + } + ctx, cancel := context.WithCancel(context.Background()) proxy := &DNSProxy{ @@ -82,9 +82,11 @@ func NewDNSProxy(tunDevice tun.Device, mtu int) (*DNSProxy, error) { } // Add IP address + // Parse the proxy IP to get the octets + ipBytes := proxyIP.As4() protoAddr := tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddrFrom4([4]byte{10, 30, 30, 30}).WithPrefix(), + AddressWithPrefix: tcpip.AddrFrom4(ipBytes).WithPrefix(), } if err := proxy.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil { @@ -101,23 +103,23 @@ func NewDNSProxy(tunDevice tun.Device, mtu int) (*DNSProxy, error) { } // Start starts the DNS proxy and registers with the filter -func (p *DNSProxy) Start(device *device.MiddleDevice) error { +func (p *DNSProxy) Start() error { // Install packet filter rule - device.AddRule(p.proxyIP, p.handlePacket) + p.middleDevice.AddRule(p.proxyIP, p.handlePacket) // Start DNS listener p.wg.Add(2) go p.runDNSListener() go p.runPacketSender() - logger.Info("DNS proxy started on %s:%d", DNSProxyIP, DNSPort) + logger.Info("DNS proxy started on %s:%d", p.proxyIP.String(), DNSPort) return nil } // Stop stops the DNS proxy -func (p *DNSProxy) Stop(device *device.MiddleDevice) { - if device != nil { - device.RemoveRule(p.proxyIP) +func (p *DNSProxy) Stop() { + if p.middleDevice != nil { + p.middleDevice.RemoveRule(p.proxyIP) } p.cancel() p.wg.Wait() @@ -174,9 +176,11 @@ func (p *DNSProxy) runDNSListener() { defer p.wg.Done() // Create UDP listener using gonet + // Parse the proxy IP to get the octets + ipBytes := p.proxyIP.As4() laddr := &tcpip.FullAddress{ NIC: 1, - Addr: tcpip.AddrFrom4([4]byte{10, 30, 30, 30}), + Addr: tcpip.AddrFrom4(ipBytes), Port: DNSPort, } @@ -322,11 +326,11 @@ func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns // forwardToUpstream forwards a DNS query to upstream DNS servers func (p *DNSProxy) forwardToUpstream(query *dns.Msg) *dns.Msg { // Try primary DNS server - response, err := p.queryUpstream(UpstreamDNS1, query, 2*time.Second) - if err != nil { + response, err := p.queryUpstream(p.upstreamDNS[0], query, 2*time.Second) + if err != nil && len(p.upstreamDNS) > 1 { // Try secondary DNS server logger.Debug("Primary DNS failed, trying secondary: %v", err) - response, err = p.queryUpstream(UpstreamDNS2, query, 2*time.Second) + response, err = p.queryUpstream(p.upstreamDNS[1], query, 2*time.Second) if err != nil { logger.Error("Both DNS servers failed: %v", err) return nil diff --git a/main.go b/main.go index 548cd42..a6a508d 100644 --- a/main.go +++ b/main.go @@ -226,6 +226,8 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { UserToken: config.UserToken, MTU: config.MTU, DNS: config.DNS, + DNSProxyIP: config.DNSProxyIP, + UpstreamDNS: config.UpstreamDNS, InterfaceName: config.InterfaceName, Holepunch: config.Holepunch, TlsClientCert: config.TlsClientCert, diff --git a/olm/olm.go b/olm/olm.go index 94098cb..178e6d5 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -47,6 +47,8 @@ type TunnelConfig struct { // Network settings MTU int DNS string + DNSProxyIP string + UpstreamDNS []string InterfaceName string // Advanced @@ -124,6 +126,8 @@ func Init(ctx context.Context, config GlobalConfig) { UserToken: req.UserToken, MTU: req.MTU, DNS: req.DNS, + DNSProxyIP: req.DNSProxyIP, + UpstreamDNS: req.UpstreamDNS, InterfaceName: req.InterfaceName, Holepunch: req.Holepunch, TlsClientCert: req.TlsClientCert, @@ -157,6 +161,11 @@ func Init(ctx context.Context, config GlobalConfig) { if req.DNS == "" { tunnelConfig.DNS = "9.9.9.9" } + // DNSProxyIP has no default - it must be provided if DNS proxy is desired + // UpstreamDNS defaults to 8.8.8.8 if not provided + if len(req.UpstreamDNS) == 0 { + tunnelConfig.UpstreamDNS = []string{"8.8.8.8"} + } if req.InterfaceName == "" { tunnelConfig.InterfaceName = "olm" } @@ -473,25 +482,26 @@ func StartTunnel(config TunnelConfig) { logger.Error("Failed to bring up WireGuard device: %v", err) } - // Create and start DNS proxy - dnsProxy, err = dns.NewDNSProxy(tdev, config.MTU) - if err != nil { - logger.Error("Failed to create DNS proxy: %v", err) - } - if err := dnsProxy.Start(middleDev); err != nil { - logger.Error("Failed to start DNS proxy: %v", err) - } - ip := net.ParseIP("192.168.1.100") - if dnsProxy.AddDNSRecord("example.com", ip); err != nil { - logger.Error("Failed to add DNS record: %v", err) + if config.DNSProxyIP != "" { + // Create and start DNS proxy + dnsProxy, err = dns.NewDNSProxy(tdev, middleDev, config.MTU, config.DNSProxyIP, config.UpstreamDNS) + if err != nil { + logger.Error("Failed to create DNS proxy: %v", err) + } + + if err := dnsProxy.Start(); err != nil { + logger.Error("Failed to start DNS proxy: %v", err) + } } if err = ConfigureInterface(interfaceName, wgData, config.MTU); err != nil { logger.Error("Failed to configure interface: %v", err) } - if addRoutes([]string{"10.30.30.30/32"}, interfaceName); err != nil { - logger.Error("Failed to add route for DNS server: %v", err) + if config.DNSProxyIP != "" { + if addRoutes([]string{config.DNSProxyIP + "/32"}, interfaceName); err != nil { + logger.Error("Failed to add route for DNS server: %v", err) + } } // TODO: seperate adding the callback to this so we can init it above with the interface @@ -661,22 +671,12 @@ func StartTunnel(config TunnelConfig) { return } - var addData AddPeerData - if err := json.Unmarshal(jsonData, &addData); err != nil { + var siteConfig SiteConfig + if err := json.Unmarshal(jsonData, &siteConfig); err != nil { logger.Error("Error unmarshaling add data: %v", err) return } - // Convert to SiteConfig - siteConfig := SiteConfig{ - SiteId: addData.SiteId, - Endpoint: addData.Endpoint, - PublicKey: addData.PublicKey, - ServerIP: addData.ServerIP, - ServerPort: addData.ServerPort, - RemoteSubnets: addData.RemoteSubnets, - } - // Add the peer to WireGuard if dev == nil { logger.Error("WireGuard device not initialized") @@ -699,7 +699,7 @@ func StartTunnel(config TunnelConfig) { } // Add successful - logger.Info("Successfully added peer for site %d", addData.SiteId) + logger.Info("Successfully added peer for site %d", siteConfig.SiteId) // Update WgData with the new peer wgData.Sites = append(wgData.Sites, siteConfig) @@ -1076,7 +1076,7 @@ func Close() { // Stop DNS proxy if dnsProxy != nil { - dnsProxy.Stop(middleDev) + dnsProxy.Stop() dnsProxy = nil } diff --git a/olm/types.go b/olm/types.go index 4610aa6..96f63b9 100644 --- a/olm/types.go +++ b/olm/types.go @@ -1,17 +1,9 @@ package olm type WgData struct { - Sites []SiteConfig `json:"sites"` - TunnelIP string `json:"tunnelIP"` -} - -type SiteConfig struct { - SiteId int `json:"siteId"` - Endpoint string `json:"endpoint"` - PublicKey string `json:"publicKey"` - ServerIP string `json:"serverIP"` - ServerPort uint16 `json:"serverPort"` - RemoteSubnets []string `json:"remoteSubnets,omitempty"` // optional, array of subnets that this site can access + Sites []SiteConfig `json:"sites"` + TunnelIP string `json:"tunnelIP"` + UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses } type HolePunchMessage struct { @@ -40,23 +32,19 @@ type PeerAction struct { } // UpdatePeerData represents the data needed to update a peer -type UpdatePeerData struct { +type SiteConfig struct { SiteId int `json:"siteId"` Endpoint string `json:"endpoint,omitempty"` PublicKey string `json:"publicKey,omitempty"` ServerIP string `json:"serverIP,omitempty"` ServerPort uint16 `json:"serverPort,omitempty"` RemoteSubnets []string `json:"remoteSubnets,omitempty"` // optional, array of subnets that this site can access + Aliases []Alias `json:"aliases,omitempty"` // optional, array of alias configurations } -// AddPeerData represents the data needed to add a peer -type AddPeerData struct { - SiteId int `json:"siteId"` - Endpoint string `json:"endpoint"` - PublicKey string `json:"publicKey"` - ServerIP string `json:"serverIP"` - ServerPort uint16 `json:"serverPort"` - RemoteSubnets []string `json:"remoteSubnets,omitempty"` // optional, array of subnets that this site can access +type Alias struct { + Alias string `json:"alias"` // the alias name + AliasAddress string `json:"aliasAddress"` // the alias IP address } // RemovePeerData represents the data needed to remove a peer