Working on sending down the dns

Former-commit-id: 1a8385c457
This commit is contained in:
Owen
2025-11-23 15:57:35 -05:00
parent b38357875e
commit 6c7ee31330
6 changed files with 143 additions and 88 deletions

View File

@@ -13,18 +13,20 @@ import (
// ConnectionRequest defines the structure for an incoming connection request
type ConnectionRequest struct {
ID string `json:"id"`
Secret string `json:"secret"`
Endpoint string `json:"endpoint"`
UserToken string `json:"userToken,omitempty"`
MTU int `json:"mtu,omitempty"`
DNS string `json:"dns,omitempty"`
InterfaceName string `json:"interfaceName,omitempty"`
Holepunch bool `json:"holepunch,omitempty"`
TlsClientCert string `json:"tlsClientCert,omitempty"`
PingInterval string `json:"pingInterval,omitempty"`
PingTimeout string `json:"pingTimeout,omitempty"`
OrgID string `json:"orgId,omitempty"`
ID string `json:"id"`
Secret string `json:"secret"`
Endpoint string `json:"endpoint"`
UserToken string `json:"userToken,omitempty"`
MTU int `json:"mtu,omitempty"`
DNS string `json:"dns,omitempty"`
DNSProxyIP string `json:"dnsProxyIP,omitempty"`
UpstreamDNS []string `json:"upstreamDNS,omitempty"`
InterfaceName string `json:"interfaceName,omitempty"`
Holepunch bool `json:"holepunch,omitempty"`
TlsClientCert string `json:"tlsClientCert,omitempty"`
PingInterval string `json:"pingInterval,omitempty"`
PingTimeout string `json:"pingTimeout,omitempty"`
OrgID string `json:"orgId,omitempty"`
}
// SwitchOrgRequest defines the structure for switching organizations

View File

@@ -8,6 +8,7 @@ import (
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
)
@@ -21,9 +22,11 @@ type OlmConfig struct {
UserToken string `json:"userToken"`
// Network settings
MTU int `json:"mtu"`
DNS string `json:"dns"`
InterfaceName string `json:"interface"`
MTU int `json:"mtu"`
DNS string `json:"dns"`
DNSProxyIP string `json:"dnsProxyIP"`
UpstreamDNS []string `json:"upstreamDNS"`
InterfaceName string `json:"interface"`
// Logging
LogLevel string `json:"logLevel"`
@@ -76,6 +79,8 @@ func DefaultConfig() *OlmConfig {
config := &OlmConfig{
MTU: 1280,
DNS: "8.8.8.8",
DNSProxyIP: "",
UpstreamDNS: []string{"8.8.8.8"},
LogLevel: "INFO",
InterfaceName: "olm",
EnableAPI: false,
@@ -90,6 +95,8 @@ func DefaultConfig() *OlmConfig {
// Track default sources
config.sources["mtu"] = string(SourceDefault)
config.sources["dns"] = string(SourceDefault)
config.sources["dnsProxyIP"] = string(SourceDefault)
config.sources["upstreamDNS"] = string(SourceDefault)
config.sources["logLevel"] = string(SourceDefault)
config.sources["interface"] = string(SourceDefault)
config.sources["enableApi"] = string(SourceDefault)
@@ -213,6 +220,14 @@ func loadConfigFromEnv(config *OlmConfig) {
config.DNS = val
config.sources["dns"] = string(SourceEnv)
}
if val := os.Getenv("DNS_PROXY_IP"); val != "" {
config.DNSProxyIP = val
config.sources["dnsProxyIP"] = string(SourceEnv)
}
if val := os.Getenv("UPSTREAM_DNS"); val != "" {
config.UpstreamDNS = []string{val}
config.sources["upstreamDNS"] = string(SourceEnv)
}
if val := os.Getenv("LOG_LEVEL"); val != "" {
config.LogLevel = val
config.sources["logLevel"] = string(SourceEnv)
@@ -264,6 +279,8 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
"userToken": config.UserToken,
"mtu": config.MTU,
"dns": config.DNS,
"dnsProxyIP": config.DNSProxyIP,
"upstreamDNS": fmt.Sprintf("%v", config.UpstreamDNS),
"logLevel": config.LogLevel,
"interface": config.InterfaceName,
"httpAddr": config.HTTPAddr,
@@ -283,6 +300,9 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
serviceFlags.StringVar(&config.UserToken, "user-token", config.UserToken, "User token (optional)")
serviceFlags.IntVar(&config.MTU, "mtu", config.MTU, "MTU to use")
serviceFlags.StringVar(&config.DNS, "dns", config.DNS, "DNS server to use")
serviceFlags.StringVar(&config.DNSProxyIP, "dns-proxy-ip", config.DNSProxyIP, "IP address for the DNS proxy (required for DNS proxy)")
var upstreamDNSFlag string
serviceFlags.StringVar(&upstreamDNSFlag, "upstream-dns", "", "Upstream DNS server(s) (comma-separated, default: 8.8.8.8)")
serviceFlags.StringVar(&config.LogLevel, "log-level", config.LogLevel, "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
serviceFlags.StringVar(&config.InterfaceName, "interface", config.InterfaceName, "Name of the WireGuard interface")
serviceFlags.StringVar(&config.HTTPAddr, "http-addr", config.HTTPAddr, "HTTP server address (e.g., ':9452')")
@@ -301,6 +321,16 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
return false, false, err
}
// Parse upstream DNS flag if provided
if upstreamDNSFlag != "" {
config.UpstreamDNS = []string{}
for _, dns := range splitComma(upstreamDNSFlag) {
if dns != "" {
config.UpstreamDNS = append(config.UpstreamDNS, dns)
}
}
}
// Track which values were changed by CLI args
if config.Endpoint != origValues["endpoint"].(string) {
config.sources["endpoint"] = string(SourceCLI)
@@ -323,6 +353,12 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
if config.DNS != origValues["dns"].(string) {
config.sources["dns"] = string(SourceCLI)
}
if config.DNSProxyIP != origValues["dnsProxyIP"].(string) {
config.sources["dnsProxyIP"] = string(SourceCLI)
}
if fmt.Sprintf("%v", config.UpstreamDNS) != origValues["upstreamDNS"].(string) {
config.sources["upstreamDNS"] = string(SourceCLI)
}
if config.LogLevel != origValues["logLevel"].(string) {
config.sources["logLevel"] = string(SourceCLI)
}
@@ -418,6 +454,14 @@ func mergeConfigs(dest, src *OlmConfig) {
dest.DNS = src.DNS
dest.sources["dns"] = string(SourceFile)
}
if src.DNSProxyIP != "" {
dest.DNSProxyIP = src.DNSProxyIP
dest.sources["dnsProxyIP"] = string(SourceFile)
}
if len(src.UpstreamDNS) > 0 && fmt.Sprintf("%v", src.UpstreamDNS) != "[8.8.8.8]" {
dest.UpstreamDNS = src.UpstreamDNS
dest.sources["upstreamDNS"] = string(SourceFile)
}
if src.LogLevel != "" && src.LogLevel != "INFO" {
dest.LogLevel = src.LogLevel
dest.sources["logLevel"] = string(SourceFile)
@@ -526,6 +570,8 @@ func (c *OlmConfig) ShowConfig() {
fmt.Println("\nNetwork:")
fmt.Printf(" mtu = %d [%s]\n", c.MTU, getSource("mtu"))
fmt.Printf(" dns = %s [%s]\n", c.DNS, getSource("dns"))
fmt.Printf(" dns-proxy-ip = %s [%s]\n", formatValue("dnsProxyIP", c.DNSProxyIP), getSource("dnsProxyIP"))
fmt.Printf(" upstream-dns = %v [%s]\n", c.UpstreamDNS, getSource("upstreamDNS"))
fmt.Printf(" interface = %s [%s]\n", c.InterfaceName, getSource("interface"))
// Logging
@@ -560,3 +606,16 @@ func (c *OlmConfig) ShowConfig() {
fmt.Println("\nPriority: cli > environment > file > default")
fmt.Println()
}
// splitComma splits a comma-separated string into a slice of trimmed strings
func splitComma(s string) []string {
parts := strings.Split(s, ",")
result := make([]string, 0, len(parts))
for _, part := range parts {
trimmed := strings.TrimSpace(part)
if trimmed != "" {
result = append(result, trimmed)
}
}
return result
}

View File

@@ -25,23 +25,19 @@ import (
)
const (
// DNS proxy listening address
DNSProxyIP = "10.30.30.30"
DNSPort = 53
// Upstream DNS servers
UpstreamDNS1 = "8.8.8.8:53"
UpstreamDNS2 = "8.8.4.4:53"
DNSPort = 53
)
// DNSProxy implements a DNS proxy using gvisor netstack
type DNSProxy struct {
stack *stack.Stack
ep *channel.Endpoint
proxyIP netip.Addr
mtu int
tunDevice tun.Device // Direct reference to underlying TUN device for responses
recordStore *DNSRecordStore // Local DNS records
stack *stack.Stack
ep *channel.Endpoint
proxyIP netip.Addr
upstreamDNS []string
mtu int
tunDevice tun.Device // Direct reference to underlying TUN device for responses
middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering
recordStore *DNSRecordStore // Local DNS records
ctx context.Context
cancel context.CancelFunc
@@ -49,12 +45,16 @@ type DNSProxy struct {
}
// NewDNSProxy creates a new DNS proxy
func NewDNSProxy(tunDevice tun.Device, mtu int) (*DNSProxy, error) {
proxyIP, err := netip.ParseAddr(DNSProxyIP)
func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu int, dnsProxyIP string, upstreamDns []string) (*DNSProxy, error) {
proxyIP, err := netip.ParseAddr(dnsProxyIP)
if err != nil {
return nil, fmt.Errorf("invalid proxy IP: %w", err)
}
if len(upstreamDns) == 0 {
return nil, fmt.Errorf("at least one upstream DNS server must be specified")
}
ctx, cancel := context.WithCancel(context.Background())
proxy := &DNSProxy{
@@ -82,9 +82,11 @@ func NewDNSProxy(tunDevice tun.Device, mtu int) (*DNSProxy, error) {
}
// Add IP address
// Parse the proxy IP to get the octets
ipBytes := proxyIP.As4()
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddrFrom4([4]byte{10, 30, 30, 30}).WithPrefix(),
AddressWithPrefix: tcpip.AddrFrom4(ipBytes).WithPrefix(),
}
if err := proxy.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
@@ -101,23 +103,23 @@ func NewDNSProxy(tunDevice tun.Device, mtu int) (*DNSProxy, error) {
}
// Start starts the DNS proxy and registers with the filter
func (p *DNSProxy) Start(device *device.MiddleDevice) error {
func (p *DNSProxy) Start() error {
// Install packet filter rule
device.AddRule(p.proxyIP, p.handlePacket)
p.middleDevice.AddRule(p.proxyIP, p.handlePacket)
// Start DNS listener
p.wg.Add(2)
go p.runDNSListener()
go p.runPacketSender()
logger.Info("DNS proxy started on %s:%d", DNSProxyIP, DNSPort)
logger.Info("DNS proxy started on %s:%d", p.proxyIP.String(), DNSPort)
return nil
}
// Stop stops the DNS proxy
func (p *DNSProxy) Stop(device *device.MiddleDevice) {
if device != nil {
device.RemoveRule(p.proxyIP)
func (p *DNSProxy) Stop() {
if p.middleDevice != nil {
p.middleDevice.RemoveRule(p.proxyIP)
}
p.cancel()
p.wg.Wait()
@@ -174,9 +176,11 @@ func (p *DNSProxy) runDNSListener() {
defer p.wg.Done()
// Create UDP listener using gonet
// Parse the proxy IP to get the octets
ipBytes := p.proxyIP.As4()
laddr := &tcpip.FullAddress{
NIC: 1,
Addr: tcpip.AddrFrom4([4]byte{10, 30, 30, 30}),
Addr: tcpip.AddrFrom4(ipBytes),
Port: DNSPort,
}
@@ -322,11 +326,11 @@ func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns
// forwardToUpstream forwards a DNS query to upstream DNS servers
func (p *DNSProxy) forwardToUpstream(query *dns.Msg) *dns.Msg {
// Try primary DNS server
response, err := p.queryUpstream(UpstreamDNS1, query, 2*time.Second)
if err != nil {
response, err := p.queryUpstream(p.upstreamDNS[0], query, 2*time.Second)
if err != nil && len(p.upstreamDNS) > 1 {
// Try secondary DNS server
logger.Debug("Primary DNS failed, trying secondary: %v", err)
response, err = p.queryUpstream(UpstreamDNS2, query, 2*time.Second)
response, err = p.queryUpstream(p.upstreamDNS[1], query, 2*time.Second)
if err != nil {
logger.Error("Both DNS servers failed: %v", err)
return nil

View File

@@ -226,6 +226,8 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
UserToken: config.UserToken,
MTU: config.MTU,
DNS: config.DNS,
DNSProxyIP: config.DNSProxyIP,
UpstreamDNS: config.UpstreamDNS,
InterfaceName: config.InterfaceName,
Holepunch: config.Holepunch,
TlsClientCert: config.TlsClientCert,

View File

@@ -47,6 +47,8 @@ type TunnelConfig struct {
// Network settings
MTU int
DNS string
DNSProxyIP string
UpstreamDNS []string
InterfaceName string
// Advanced
@@ -124,6 +126,8 @@ func Init(ctx context.Context, config GlobalConfig) {
UserToken: req.UserToken,
MTU: req.MTU,
DNS: req.DNS,
DNSProxyIP: req.DNSProxyIP,
UpstreamDNS: req.UpstreamDNS,
InterfaceName: req.InterfaceName,
Holepunch: req.Holepunch,
TlsClientCert: req.TlsClientCert,
@@ -157,6 +161,11 @@ func Init(ctx context.Context, config GlobalConfig) {
if req.DNS == "" {
tunnelConfig.DNS = "9.9.9.9"
}
// DNSProxyIP has no default - it must be provided if DNS proxy is desired
// UpstreamDNS defaults to 8.8.8.8 if not provided
if len(req.UpstreamDNS) == 0 {
tunnelConfig.UpstreamDNS = []string{"8.8.8.8"}
}
if req.InterfaceName == "" {
tunnelConfig.InterfaceName = "olm"
}
@@ -473,25 +482,26 @@ func StartTunnel(config TunnelConfig) {
logger.Error("Failed to bring up WireGuard device: %v", err)
}
// Create and start DNS proxy
dnsProxy, err = dns.NewDNSProxy(tdev, config.MTU)
if err != nil {
logger.Error("Failed to create DNS proxy: %v", err)
}
if err := dnsProxy.Start(middleDev); err != nil {
logger.Error("Failed to start DNS proxy: %v", err)
}
ip := net.ParseIP("192.168.1.100")
if dnsProxy.AddDNSRecord("example.com", ip); err != nil {
logger.Error("Failed to add DNS record: %v", err)
if config.DNSProxyIP != "" {
// Create and start DNS proxy
dnsProxy, err = dns.NewDNSProxy(tdev, middleDev, config.MTU, config.DNSProxyIP, config.UpstreamDNS)
if err != nil {
logger.Error("Failed to create DNS proxy: %v", err)
}
if err := dnsProxy.Start(); err != nil {
logger.Error("Failed to start DNS proxy: %v", err)
}
}
if err = ConfigureInterface(interfaceName, wgData, config.MTU); err != nil {
logger.Error("Failed to configure interface: %v", err)
}
if addRoutes([]string{"10.30.30.30/32"}, interfaceName); err != nil {
logger.Error("Failed to add route for DNS server: %v", err)
if config.DNSProxyIP != "" {
if addRoutes([]string{config.DNSProxyIP + "/32"}, interfaceName); err != nil {
logger.Error("Failed to add route for DNS server: %v", err)
}
}
// TODO: seperate adding the callback to this so we can init it above with the interface
@@ -661,22 +671,12 @@ func StartTunnel(config TunnelConfig) {
return
}
var addData AddPeerData
if err := json.Unmarshal(jsonData, &addData); err != nil {
var siteConfig SiteConfig
if err := json.Unmarshal(jsonData, &siteConfig); err != nil {
logger.Error("Error unmarshaling add data: %v", err)
return
}
// Convert to SiteConfig
siteConfig := SiteConfig{
SiteId: addData.SiteId,
Endpoint: addData.Endpoint,
PublicKey: addData.PublicKey,
ServerIP: addData.ServerIP,
ServerPort: addData.ServerPort,
RemoteSubnets: addData.RemoteSubnets,
}
// Add the peer to WireGuard
if dev == nil {
logger.Error("WireGuard device not initialized")
@@ -699,7 +699,7 @@ func StartTunnel(config TunnelConfig) {
}
// Add successful
logger.Info("Successfully added peer for site %d", addData.SiteId)
logger.Info("Successfully added peer for site %d", siteConfig.SiteId)
// Update WgData with the new peer
wgData.Sites = append(wgData.Sites, siteConfig)
@@ -1076,7 +1076,7 @@ func Close() {
// Stop DNS proxy
if dnsProxy != nil {
dnsProxy.Stop(middleDev)
dnsProxy.Stop()
dnsProxy = nil
}

View File

@@ -1,17 +1,9 @@
package olm
type WgData struct {
Sites []SiteConfig `json:"sites"`
TunnelIP string `json:"tunnelIP"`
}
type SiteConfig struct {
SiteId int `json:"siteId"`
Endpoint string `json:"endpoint"`
PublicKey string `json:"publicKey"`
ServerIP string `json:"serverIP"`
ServerPort uint16 `json:"serverPort"`
RemoteSubnets []string `json:"remoteSubnets,omitempty"` // optional, array of subnets that this site can access
Sites []SiteConfig `json:"sites"`
TunnelIP string `json:"tunnelIP"`
UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses
}
type HolePunchMessage struct {
@@ -40,23 +32,19 @@ type PeerAction struct {
}
// UpdatePeerData represents the data needed to update a peer
type UpdatePeerData struct {
type SiteConfig struct {
SiteId int `json:"siteId"`
Endpoint string `json:"endpoint,omitempty"`
PublicKey string `json:"publicKey,omitempty"`
ServerIP string `json:"serverIP,omitempty"`
ServerPort uint16 `json:"serverPort,omitempty"`
RemoteSubnets []string `json:"remoteSubnets,omitempty"` // optional, array of subnets that this site can access
Aliases []Alias `json:"aliases,omitempty"` // optional, array of alias configurations
}
// AddPeerData represents the data needed to add a peer
type AddPeerData struct {
SiteId int `json:"siteId"`
Endpoint string `json:"endpoint"`
PublicKey string `json:"publicKey"`
ServerIP string `json:"serverIP"`
ServerPort uint16 `json:"serverPort"`
RemoteSubnets []string `json:"remoteSubnets,omitempty"` // optional, array of subnets that this site can access
type Alias struct {
Alias string `json:"alias"` // the alias name
AliasAddress string `json:"aliasAddress"` // the alias IP address
}
// RemovePeerData represents the data needed to remove a peer