Basic working

Former-commit-id: 4dd50526cf
This commit is contained in:
Owen
2025-11-23 22:01:43 -05:00
parent 50008f3c12
commit ead8fab70a
4 changed files with 64 additions and 12 deletions

View File

@@ -78,7 +78,7 @@ func DefaultConfig() *OlmConfig {
config := &OlmConfig{ config := &OlmConfig{
MTU: 1280, MTU: 1280,
DNS: "8.8.8.8", DNS: "8.8.8.8",
UpstreamDNS: []string{"8.8.8.8"}, UpstreamDNS: []string{"8.8.8.8:53"},
LogLevel: "INFO", LogLevel: "INFO",
InterfaceName: "olm", InterfaceName: "olm",
EnableAPI: false, EnableAPI: false,
@@ -293,7 +293,7 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
serviceFlags.IntVar(&config.MTU, "mtu", config.MTU, "MTU to use") serviceFlags.IntVar(&config.MTU, "mtu", config.MTU, "MTU to use")
serviceFlags.StringVar(&config.DNS, "dns", config.DNS, "DNS server to use") serviceFlags.StringVar(&config.DNS, "dns", config.DNS, "DNS server to use")
var upstreamDNSFlag string var upstreamDNSFlag string
serviceFlags.StringVar(&upstreamDNSFlag, "upstream-dns", "", "Upstream DNS server(s) (comma-separated, default: 8.8.8.8)") serviceFlags.StringVar(&upstreamDNSFlag, "upstream-dns", "", "Upstream DNS server(s) (comma-separated, default: 8.8.8.8:53)")
serviceFlags.StringVar(&config.LogLevel, "log-level", config.LogLevel, "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") serviceFlags.StringVar(&config.LogLevel, "log-level", config.LogLevel, "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
serviceFlags.StringVar(&config.InterfaceName, "interface", config.InterfaceName, "Name of the WireGuard interface") serviceFlags.StringVar(&config.InterfaceName, "interface", config.InterfaceName, "Name of the WireGuard interface")
serviceFlags.StringVar(&config.HTTPAddr, "http-addr", config.HTTPAddr, "HTTP server address (e.g., ':9452')") serviceFlags.StringVar(&config.HTTPAddr, "http-addr", config.HTTPAddr, "HTTP server address (e.g., ':9452')")
@@ -442,7 +442,7 @@ func mergeConfigs(dest, src *OlmConfig) {
dest.DNS = src.DNS dest.DNS = src.DNS
dest.sources["dns"] = string(SourceFile) dest.sources["dns"] = string(SourceFile)
} }
if len(src.UpstreamDNS) > 0 && fmt.Sprintf("%v", src.UpstreamDNS) != "[8.8.8.8]" { if len(src.UpstreamDNS) > 0 && fmt.Sprintf("%v", src.UpstreamDNS) != "[8.8.8.8:53]" {
dest.UpstreamDNS = src.UpstreamDNS dest.UpstreamDNS = src.UpstreamDNS
dest.sources["upstreamDNS"] = string(SourceFile) dest.sources["upstreamDNS"] = string(SourceFile)
} }

View File

@@ -58,12 +58,14 @@ func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu in
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
proxy := &DNSProxy{ proxy := &DNSProxy{
proxyIP: proxyIP, proxyIP: proxyIP,
mtu: mtu, mtu: mtu,
tunDevice: tunDevice, tunDevice: tunDevice,
recordStore: NewDNSRecordStore(), middleDevice: middleDevice,
ctx: ctx, upstreamDNS: upstreamDns,
cancel: cancel, recordStore: NewDNSRecordStore(),
ctx: ctx,
cancel: cancel,
} }
// Create gvisor netstack // Create gvisor netstack
@@ -134,6 +136,10 @@ func (p *DNSProxy) Stop() {
logger.Info("DNS proxy stopped") logger.Info("DNS proxy stopped")
} }
func (p *DNSProxy) GetProxyIP() netip.Addr {
return p.proxyIP
}
// handlePacket is called by the filter for packets destined to DNS proxy IP // handlePacket is called by the filter for packets destined to DNS proxy IP
func (p *DNSProxy) handlePacket(packet []byte) bool { func (p *DNSProxy) handlePacket(packet []byte) bool {
if len(packet) < 20 { if len(packet) < 20 {
@@ -248,7 +254,7 @@ func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clie
// If no local records, forward to upstream // If no local records, forward to upstream
if response == nil { if response == nil {
logger.Debug("No local record for %s, forwarding upstream", question.Name) logger.Debug("No local record for %s, forwarding upstream to %v", question.Name, p.upstreamDNS)
response = p.forwardToUpstream(msg) response = p.forwardToUpstream(msg)
} }

View File

@@ -4,7 +4,9 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log"
"net" "net"
"net/netip"
"runtime" "runtime"
"strings" "strings"
"time" "time"
@@ -16,6 +18,7 @@ import (
"github.com/fosrl/olm/api" "github.com/fosrl/olm/api"
middleDevice "github.com/fosrl/olm/device" middleDevice "github.com/fosrl/olm/device"
"github.com/fosrl/olm/dns" "github.com/fosrl/olm/dns"
platform "github.com/fosrl/olm/dns/platform"
"github.com/fosrl/olm/network" "github.com/fosrl/olm/network"
"github.com/fosrl/olm/peermonitor" "github.com/fosrl/olm/peermonitor"
"github.com/fosrl/olm/websocket" "github.com/fosrl/olm/websocket"
@@ -91,6 +94,7 @@ var (
globalCtx context.Context globalCtx context.Context
stopRegister func() stopRegister func()
stopPing chan struct{} stopPing chan struct{}
configurator platform.DNSConfigurator
) )
func Init(ctx context.Context, config GlobalConfig) { func Init(ctx context.Context, config GlobalConfig) {
@@ -167,7 +171,7 @@ func Init(ctx context.Context, config GlobalConfig) {
// DNSProxyIP has no default - it must be provided if DNS proxy is desired // DNSProxyIP has no default - it must be provided if DNS proxy is desired
// UpstreamDNS defaults to 8.8.8.8 if not provided // UpstreamDNS defaults to 8.8.8.8 if not provided
if len(req.UpstreamDNS) == 0 { if len(req.UpstreamDNS) == 0 {
tunnelConfig.UpstreamDNS = []string{"8.8.8.8"} tunnelConfig.UpstreamDNS = []string{"8.8.8.8:53"}
} }
if req.InterfaceName == "" { if req.InterfaceName == "" {
tunnelConfig.InterfaceName = "olm" tunnelConfig.InterfaceName = "olm"
@@ -485,6 +489,9 @@ func StartTunnel(config TunnelConfig) {
logger.Error("Failed to bring up WireGuard device: %v", err) logger.Error("Failed to bring up WireGuard device: %v", err)
} }
// TODO: REMOVE HARDCODE
wgData.UtilitySubnet = "100.81.0.0/24"
// Create and start DNS proxy // Create and start DNS proxy
dnsProxy, err = dns.NewDNSProxy(tdev, middleDev, config.MTU, wgData.UtilitySubnet, config.UpstreamDNS) dnsProxy, err = dns.NewDNSProxy(tdev, middleDev, config.MTU, wgData.UtilitySubnet, config.UpstreamDNS)
if err != nil { if err != nil {
@@ -570,6 +577,37 @@ func StartTunnel(config TunnelConfig) {
peerMonitor.Start() peerMonitor.Start()
configurator, err = platform.DetectBestConfigurator(interfaceName)
if err != nil {
log.Fatalf("Failed to detect DNS configurator: %v", err)
}
fmt.Printf("Using DNS configurator: %s\n", configurator.Name())
// Get current DNS servers before changing
currentDNS, err := configurator.GetCurrentDNS()
if err != nil {
log.Printf("Warning: Could not get current DNS: %v", err)
} else {
fmt.Printf("Current DNS servers: %v\n", currentDNS)
}
// Set new DNS servers
newDNS := []netip.Addr{
dnsProxy.GetProxyIP(),
// netip.MustParseAddr("8.8.8.8"), // Google
}
fmt.Printf("Setting DNS servers to: %v\n", newDNS)
originalDNS, err := configurator.SetDNS(newDNS)
if err != nil {
log.Fatalf("Failed to set DNS: %v", err)
}
for _, addr := range originalDNS {
fmt.Printf("Original DNS server: %v\n", addr)
}
if err := dnsProxy.Start(); err != nil { if err := dnsProxy.Start(); err != nil {
logger.Error("Failed to start DNS proxy: %v", err) logger.Error("Failed to start DNS proxy: %v", err)
} }
@@ -1110,6 +1148,14 @@ func Close() {
middleDev = nil middleDev = nil
} }
// Restore original DNS
if configurator != nil {
fmt.Println("Restoring original DNS servers...")
if err := configurator.RestoreDNS(); err != nil {
log.Fatalf("Failed to restore DNS: %v", err)
}
}
// Stop DNS proxy // Stop DNS proxy
logger.Debug("Stopping DNS proxy") logger.Debug("Stopping DNS proxy")
if dnsProxy != nil { if dnsProxy != nil {

View File

@@ -11,7 +11,7 @@ import (
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
) )
func createTUNFromFD(tunFdStr string, mtuInt int) (tun.Device, error) { func createTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
return nil, errors.New("CreateTUNFromFile not supported on Windows") return nil, errors.New("CreateTUNFromFile not supported on Windows")
} }