diff --git a/config.go b/config.go index 1c98719..6f76893 100644 --- a/config.go +++ b/config.go @@ -78,7 +78,7 @@ func DefaultConfig() *OlmConfig { config := &OlmConfig{ MTU: 1280, DNS: "8.8.8.8", - UpstreamDNS: []string{"8.8.8.8"}, + UpstreamDNS: []string{"8.8.8.8:53"}, LogLevel: "INFO", InterfaceName: "olm", EnableAPI: false, @@ -293,7 +293,7 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) { serviceFlags.IntVar(&config.MTU, "mtu", config.MTU, "MTU to use") serviceFlags.StringVar(&config.DNS, "dns", config.DNS, "DNS server to use") var upstreamDNSFlag string - serviceFlags.StringVar(&upstreamDNSFlag, "upstream-dns", "", "Upstream DNS server(s) (comma-separated, default: 8.8.8.8)") + serviceFlags.StringVar(&upstreamDNSFlag, "upstream-dns", "", "Upstream DNS server(s) (comma-separated, default: 8.8.8.8:53)") serviceFlags.StringVar(&config.LogLevel, "log-level", config.LogLevel, "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") serviceFlags.StringVar(&config.InterfaceName, "interface", config.InterfaceName, "Name of the WireGuard interface") serviceFlags.StringVar(&config.HTTPAddr, "http-addr", config.HTTPAddr, "HTTP server address (e.g., ':9452')") @@ -442,7 +442,7 @@ func mergeConfigs(dest, src *OlmConfig) { dest.DNS = src.DNS dest.sources["dns"] = string(SourceFile) } - if len(src.UpstreamDNS) > 0 && fmt.Sprintf("%v", src.UpstreamDNS) != "[8.8.8.8]" { + if len(src.UpstreamDNS) > 0 && fmt.Sprintf("%v", src.UpstreamDNS) != "[8.8.8.8:53]" { dest.UpstreamDNS = src.UpstreamDNS dest.sources["upstreamDNS"] = string(SourceFile) } diff --git a/dns/dns_proxy.go b/dns/dns_proxy.go index c449fe5..7bb644c 100644 --- a/dns/dns_proxy.go +++ b/dns/dns_proxy.go @@ -58,12 +58,14 @@ func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu in ctx, cancel := context.WithCancel(context.Background()) proxy := &DNSProxy{ - proxyIP: proxyIP, - mtu: mtu, - tunDevice: tunDevice, - recordStore: NewDNSRecordStore(), - ctx: ctx, - cancel: cancel, + proxyIP: proxyIP, + mtu: mtu, + tunDevice: tunDevice, + middleDevice: middleDevice, + upstreamDNS: upstreamDns, + recordStore: NewDNSRecordStore(), + ctx: ctx, + cancel: cancel, } // Create gvisor netstack @@ -134,6 +136,10 @@ func (p *DNSProxy) Stop() { logger.Info("DNS proxy stopped") } +func (p *DNSProxy) GetProxyIP() netip.Addr { + return p.proxyIP +} + // handlePacket is called by the filter for packets destined to DNS proxy IP func (p *DNSProxy) handlePacket(packet []byte) bool { if len(packet) < 20 { @@ -248,7 +254,7 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie // If no local records, forward to upstream if response == nil { - logger.Debug("No local record for %s, forwarding upstream", question.Name) + logger.Debug("No local record for %s, forwarding upstream to %v", question.Name, p.upstreamDNS) response = p.forwardToUpstream(msg) } diff --git a/olm/olm.go b/olm/olm.go index f3431e2..1b4ca39 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -4,7 +4,9 @@ import ( "context" "encoding/json" "fmt" + "log" "net" + "net/netip" "runtime" "strings" "time" @@ -16,6 +18,7 @@ import ( "github.com/fosrl/olm/api" middleDevice "github.com/fosrl/olm/device" "github.com/fosrl/olm/dns" + platform "github.com/fosrl/olm/dns/platform" "github.com/fosrl/olm/network" "github.com/fosrl/olm/peermonitor" "github.com/fosrl/olm/websocket" @@ -91,6 +94,7 @@ var ( globalCtx context.Context stopRegister func() stopPing chan struct{} + configurator platform.DNSConfigurator ) func Init(ctx context.Context, config GlobalConfig) { @@ -167,7 +171,7 @@ func Init(ctx context.Context, config GlobalConfig) { // DNSProxyIP has no default - it must be provided if DNS proxy is desired // UpstreamDNS defaults to 8.8.8.8 if not provided if len(req.UpstreamDNS) == 0 { - tunnelConfig.UpstreamDNS = []string{"8.8.8.8"} + tunnelConfig.UpstreamDNS = []string{"8.8.8.8:53"} } if req.InterfaceName == "" { tunnelConfig.InterfaceName = "olm" @@ -485,6 +489,9 @@ func StartTunnel(config TunnelConfig) { logger.Error("Failed to bring up WireGuard device: %v", err) } + // TODO: REMOVE HARDCODE + wgData.UtilitySubnet = "100.81.0.0/24" + // Create and start DNS proxy dnsProxy, err = dns.NewDNSProxy(tdev, middleDev, config.MTU, wgData.UtilitySubnet, config.UpstreamDNS) if err != nil { @@ -570,6 +577,37 @@ func StartTunnel(config TunnelConfig) { peerMonitor.Start() + configurator, err = platform.DetectBestConfigurator(interfaceName) + if err != nil { + log.Fatalf("Failed to detect DNS configurator: %v", err) + } + + fmt.Printf("Using DNS configurator: %s\n", configurator.Name()) + + // Get current DNS servers before changing + currentDNS, err := configurator.GetCurrentDNS() + if err != nil { + log.Printf("Warning: Could not get current DNS: %v", err) + } else { + fmt.Printf("Current DNS servers: %v\n", currentDNS) + } + + // Set new DNS servers + newDNS := []netip.Addr{ + dnsProxy.GetProxyIP(), + // netip.MustParseAddr("8.8.8.8"), // Google + } + + fmt.Printf("Setting DNS servers to: %v\n", newDNS) + originalDNS, err := configurator.SetDNS(newDNS) + if err != nil { + log.Fatalf("Failed to set DNS: %v", err) + } + + for _, addr := range originalDNS { + fmt.Printf("Original DNS server: %v\n", addr) + } + if err := dnsProxy.Start(); err != nil { logger.Error("Failed to start DNS proxy: %v", err) } @@ -1110,6 +1148,14 @@ func Close() { middleDev = nil } + // Restore original DNS + if configurator != nil { + fmt.Println("Restoring original DNS servers...") + if err := configurator.RestoreDNS(); err != nil { + log.Fatalf("Failed to restore DNS: %v", err) + } + } + // Stop DNS proxy logger.Debug("Stopping DNS proxy") if dnsProxy != nil { diff --git a/olm/windows.go b/olm/windows.go index 772e51a..b168930 100644 --- a/olm/windows.go +++ b/olm/windows.go @@ -11,7 +11,7 @@ import ( "golang.zx2c4.com/wireguard/tun" ) -func createTUNFromFD(tunFdStr string, mtuInt int) (tun.Device, error) { +func createTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) { return nil, errors.New("CreateTUNFromFile not supported on Windows") }