diff --git a/config.go b/config.go index 707b3ec..1c98719 100644 --- a/config.go +++ b/config.go @@ -24,7 +24,6 @@ type OlmConfig struct { // Network settings MTU int `json:"mtu"` DNS string `json:"dns"` - DNSProxyIP string `json:"dnsProxyIP"` UpstreamDNS []string `json:"upstreamDNS"` InterfaceName string `json:"interface"` @@ -79,7 +78,6 @@ func DefaultConfig() *OlmConfig { config := &OlmConfig{ MTU: 1280, DNS: "8.8.8.8", - DNSProxyIP: "", UpstreamDNS: []string{"8.8.8.8"}, LogLevel: "INFO", InterfaceName: "olm", @@ -95,7 +93,6 @@ func DefaultConfig() *OlmConfig { // Track default sources config.sources["mtu"] = string(SourceDefault) config.sources["dns"] = string(SourceDefault) - config.sources["dnsProxyIP"] = string(SourceDefault) config.sources["upstreamDNS"] = string(SourceDefault) config.sources["logLevel"] = string(SourceDefault) config.sources["interface"] = string(SourceDefault) @@ -220,10 +217,6 @@ func loadConfigFromEnv(config *OlmConfig) { config.DNS = val config.sources["dns"] = string(SourceEnv) } - if val := os.Getenv("DNS_PROXY_IP"); val != "" { - config.DNSProxyIP = val - config.sources["dnsProxyIP"] = string(SourceEnv) - } if val := os.Getenv("UPSTREAM_DNS"); val != "" { config.UpstreamDNS = []string{val} config.sources["upstreamDNS"] = string(SourceEnv) @@ -279,7 +272,6 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) { "userToken": config.UserToken, "mtu": config.MTU, "dns": config.DNS, - "dnsProxyIP": config.DNSProxyIP, "upstreamDNS": fmt.Sprintf("%v", config.UpstreamDNS), "logLevel": config.LogLevel, "interface": config.InterfaceName, @@ -300,7 +292,6 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) { serviceFlags.StringVar(&config.UserToken, "user-token", config.UserToken, "User token (optional)") serviceFlags.IntVar(&config.MTU, "mtu", config.MTU, "MTU to use") serviceFlags.StringVar(&config.DNS, "dns", config.DNS, "DNS server to use") - serviceFlags.StringVar(&config.DNSProxyIP, "dns-proxy-ip", config.DNSProxyIP, "IP address for the DNS proxy (required for DNS proxy)") var upstreamDNSFlag string serviceFlags.StringVar(&upstreamDNSFlag, "upstream-dns", "", "Upstream DNS server(s) (comma-separated, default: 8.8.8.8)") serviceFlags.StringVar(&config.LogLevel, "log-level", config.LogLevel, "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") @@ -353,9 +344,6 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) { if config.DNS != origValues["dns"].(string) { config.sources["dns"] = string(SourceCLI) } - if config.DNSProxyIP != origValues["dnsProxyIP"].(string) { - config.sources["dnsProxyIP"] = string(SourceCLI) - } if fmt.Sprintf("%v", config.UpstreamDNS) != origValues["upstreamDNS"].(string) { config.sources["upstreamDNS"] = string(SourceCLI) } @@ -454,10 +442,6 @@ func mergeConfigs(dest, src *OlmConfig) { dest.DNS = src.DNS dest.sources["dns"] = string(SourceFile) } - if src.DNSProxyIP != "" { - dest.DNSProxyIP = src.DNSProxyIP - dest.sources["dnsProxyIP"] = string(SourceFile) - } if len(src.UpstreamDNS) > 0 && fmt.Sprintf("%v", src.UpstreamDNS) != "[8.8.8.8]" { dest.UpstreamDNS = src.UpstreamDNS dest.sources["upstreamDNS"] = string(SourceFile) @@ -570,7 +554,6 @@ func (c *OlmConfig) ShowConfig() { fmt.Println("\nNetwork:") fmt.Printf(" mtu = %d [%s]\n", c.MTU, getSource("mtu")) fmt.Printf(" dns = %s [%s]\n", c.DNS, getSource("dns")) - fmt.Printf(" dns-proxy-ip = %s [%s]\n", formatValue("dnsProxyIP", c.DNSProxyIP), getSource("dnsProxyIP")) fmt.Printf(" upstream-dns = %v [%s]\n", c.UpstreamDNS, getSource("upstreamDNS")) fmt.Printf(" interface = %s [%s]\n", c.InterfaceName, getSource("interface")) diff --git a/dns/dns_proxy.go b/dns/dns_proxy.go index 3103c56..c449fe5 100644 --- a/dns/dns_proxy.go +++ b/dns/dns_proxy.go @@ -45,10 +45,10 @@ type DNSProxy struct { } // NewDNSProxy creates a new DNS proxy -func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu int, dnsProxyIP string, upstreamDns []string) (*DNSProxy, error) { - proxyIP, err := netip.ParseAddr(dnsProxyIP) +func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string) (*DNSProxy, error) { + proxyIP, err := PickIPFromSubnet(utilitySubnet) if err != nil { - return nil, fmt.Errorf("invalid proxy IP: %w", err) + return nil, fmt.Errorf("failed to pick DNS proxy IP from subnet: %v", err) } if len(upstreamDns) == 0 { @@ -430,3 +430,19 @@ func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) []net.IP func (p *DNSProxy) ClearDNSRecords() { p.recordStore.Clear() } + +func PickIPFromSubnet(subnet string) (netip.Addr, error) { + // given a subnet in CIDR notation, pick the first usable IP + prefix, err := netip.ParsePrefix(subnet) + if err != nil { + return netip.Addr{}, fmt.Errorf("invalid subnet: %w", err) + } + + // Pick the first usable IP address from the subnet + ip := prefix.Addr().Next() + if !ip.IsValid() { + return netip.Addr{}, fmt.Errorf("no valid IP address found in subnet: %s", subnet) + } + + return ip, nil +} diff --git a/main.go b/main.go index a6a508d..fc559bc 100644 --- a/main.go +++ b/main.go @@ -226,7 +226,6 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { UserToken: config.UserToken, MTU: config.MTU, DNS: config.DNS, - DNSProxyIP: config.DNSProxyIP, UpstreamDNS: config.UpstreamDNS, InterfaceName: config.InterfaceName, Holepunch: config.Holepunch, diff --git a/olm/olm.go b/olm/olm.go index 25a3bea..f3431e2 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -52,7 +52,6 @@ type TunnelConfig struct { // Network settings MTU int DNS string - DNSProxyIP string UpstreamDNS []string InterfaceName string @@ -131,7 +130,6 @@ func Init(ctx context.Context, config GlobalConfig) { UserToken: req.UserToken, MTU: req.MTU, DNS: req.DNS, - DNSProxyIP: req.DNSProxyIP, UpstreamDNS: req.UpstreamDNS, InterfaceName: req.InterfaceName, Holepunch: req.Holepunch, @@ -487,26 +485,18 @@ func StartTunnel(config TunnelConfig) { logger.Error("Failed to bring up WireGuard device: %v", err) } - if config.DNSProxyIP != "" { - // Create and start DNS proxy - dnsProxy, err = dns.NewDNSProxy(tdev, middleDev, config.MTU, config.DNSProxyIP, config.UpstreamDNS) - if err != nil { - logger.Error("Failed to create DNS proxy: %v", err) - } - - if err := dnsProxy.Start(); err != nil { - logger.Error("Failed to start DNS proxy: %v", err) - } + // Create and start DNS proxy + dnsProxy, err = dns.NewDNSProxy(tdev, middleDev, config.MTU, wgData.UtilitySubnet, config.UpstreamDNS) + if err != nil { + logger.Error("Failed to create DNS proxy: %v", err) } if err = ConfigureInterface(interfaceName, wgData, config.MTU); err != nil { logger.Error("Failed to configure interface: %v", err) } - if config.DNSProxyIP != "" { - if addRoutes([]string{config.DNSProxyIP + "/32"}, interfaceName); err != nil { - logger.Error("Failed to add route for DNS server: %v", err) - } + if addRoutes([]string{wgData.UtilitySubnet}, interfaceName); err != nil { // also route the utility subnet + logger.Error("Failed to add route for utility subnet: %v", err) } // TODO: seperate adding the callback to this so we can init it above with the interface @@ -565,16 +555,14 @@ func StartTunnel(config TunnelConfig) { } for _, alias := range site.Aliases { - if dnsProxy != nil { // some times this is not initialized - // try to parse the alias address into net.IP - address := net.ParseIP(alias.AliasAddress) - if address == nil { - logger.Warn("Invalid alias address for %s: %s", alias.Alias, alias.AliasAddress) - continue - } - - dnsProxy.AddDNSRecord(alias.Alias, address) + // try to parse the alias address into net.IP + address := net.ParseIP(alias.AliasAddress) + if address == nil { + logger.Warn("Invalid alias address for %s: %s", alias.Alias, alias.AliasAddress) + continue } + + dnsProxy.AddDNSRecord(alias.Alias, address) } logger.Info("Configured peer %s", site.PublicKey) @@ -582,6 +570,10 @@ func StartTunnel(config TunnelConfig) { peerMonitor.Start() + if err := dnsProxy.Start(); err != nil { + logger.Error("Failed to start DNS proxy: %v", err) + } + apiServer.SetRegistered(true) connected = true