mirror of
https://github.com/fosrl/olm.git
synced 2026-03-02 00:36:44 +00:00
13
config.go
13
config.go
@@ -43,6 +43,7 @@ type OlmConfig struct {
|
|||||||
DisableHolepunch bool `json:"disableHolepunch"`
|
DisableHolepunch bool `json:"disableHolepunch"`
|
||||||
TlsClientCert string `json:"tlsClientCert"`
|
TlsClientCert string `json:"tlsClientCert"`
|
||||||
OverrideDNS bool `json:"overrideDNS"`
|
OverrideDNS bool `json:"overrideDNS"`
|
||||||
|
TunnelDNS bool `json:"tunnelDNS"`
|
||||||
DisableRelay bool `json:"disableRelay"`
|
DisableRelay bool `json:"disableRelay"`
|
||||||
// DoNotCreateNewClient bool `json:"doNotCreateNewClient"`
|
// DoNotCreateNewClient bool `json:"doNotCreateNewClient"`
|
||||||
|
|
||||||
@@ -88,6 +89,7 @@ func DefaultConfig() *OlmConfig {
|
|||||||
PingInterval: "3s",
|
PingInterval: "3s",
|
||||||
PingTimeout: "5s",
|
PingTimeout: "5s",
|
||||||
DisableHolepunch: false,
|
DisableHolepunch: false,
|
||||||
|
TunnelDNS: false,
|
||||||
// DoNotCreateNewClient: false,
|
// DoNotCreateNewClient: false,
|
||||||
sources: make(map[string]string),
|
sources: make(map[string]string),
|
||||||
}
|
}
|
||||||
@@ -105,6 +107,7 @@ func DefaultConfig() *OlmConfig {
|
|||||||
config.sources["pingTimeout"] = string(SourceDefault)
|
config.sources["pingTimeout"] = string(SourceDefault)
|
||||||
config.sources["disableHolepunch"] = string(SourceDefault)
|
config.sources["disableHolepunch"] = string(SourceDefault)
|
||||||
config.sources["overrideDNS"] = string(SourceDefault)
|
config.sources["overrideDNS"] = string(SourceDefault)
|
||||||
|
config.sources["tunnelDNS"] = string(SourceDefault)
|
||||||
config.sources["disableRelay"] = string(SourceDefault)
|
config.sources["disableRelay"] = string(SourceDefault)
|
||||||
// config.sources["doNotCreateNewClient"] = string(SourceDefault)
|
// config.sources["doNotCreateNewClient"] = string(SourceDefault)
|
||||||
|
|
||||||
@@ -265,6 +268,10 @@ func loadConfigFromEnv(config *OlmConfig) {
|
|||||||
config.DisableRelay = true
|
config.DisableRelay = true
|
||||||
config.sources["disableRelay"] = string(SourceEnv)
|
config.sources["disableRelay"] = string(SourceEnv)
|
||||||
}
|
}
|
||||||
|
if val := os.Getenv("TUNNEL_DNS"); val == "true" {
|
||||||
|
config.TunnelDNS = true
|
||||||
|
config.sources["tunnelDNS"] = string(SourceEnv)
|
||||||
|
}
|
||||||
// if val := os.Getenv("DO_NOT_CREATE_NEW_CLIENT"); val == "true" {
|
// if val := os.Getenv("DO_NOT_CREATE_NEW_CLIENT"); val == "true" {
|
||||||
// config.DoNotCreateNewClient = true
|
// config.DoNotCreateNewClient = true
|
||||||
// config.sources["doNotCreateNewClient"] = string(SourceEnv)
|
// config.sources["doNotCreateNewClient"] = string(SourceEnv)
|
||||||
@@ -295,6 +302,7 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
|
|||||||
"disableHolepunch": config.DisableHolepunch,
|
"disableHolepunch": config.DisableHolepunch,
|
||||||
"overrideDNS": config.OverrideDNS,
|
"overrideDNS": config.OverrideDNS,
|
||||||
"disableRelay": config.DisableRelay,
|
"disableRelay": config.DisableRelay,
|
||||||
|
"tunnelDNS": config.TunnelDNS,
|
||||||
// "doNotCreateNewClient": config.DoNotCreateNewClient,
|
// "doNotCreateNewClient": config.DoNotCreateNewClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -318,6 +326,7 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
|
|||||||
serviceFlags.BoolVar(&config.DisableHolepunch, "disable-holepunch", config.DisableHolepunch, "Disable hole punching")
|
serviceFlags.BoolVar(&config.DisableHolepunch, "disable-holepunch", config.DisableHolepunch, "Disable hole punching")
|
||||||
serviceFlags.BoolVar(&config.OverrideDNS, "override-dns", config.OverrideDNS, "Override system DNS settings")
|
serviceFlags.BoolVar(&config.OverrideDNS, "override-dns", config.OverrideDNS, "Override system DNS settings")
|
||||||
serviceFlags.BoolVar(&config.DisableRelay, "disable-relay", config.DisableRelay, "Disable relay connections")
|
serviceFlags.BoolVar(&config.DisableRelay, "disable-relay", config.DisableRelay, "Disable relay connections")
|
||||||
|
serviceFlags.BoolVar(&config.TunnelDNS, "tunnel-dns", config.TunnelDNS, "Use tunnel for DNS traffic")
|
||||||
// serviceFlags.BoolVar(&config.DoNotCreateNewClient, "do-not-create-new-client", config.DoNotCreateNewClient, "Do not create new client")
|
// serviceFlags.BoolVar(&config.DoNotCreateNewClient, "do-not-create-new-client", config.DoNotCreateNewClient, "Do not create new client")
|
||||||
|
|
||||||
version := serviceFlags.Bool("version", false, "Print the version")
|
version := serviceFlags.Bool("version", false, "Print the version")
|
||||||
@@ -393,6 +402,9 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
|
|||||||
if config.DisableRelay != origValues["disableRelay"].(bool) {
|
if config.DisableRelay != origValues["disableRelay"].(bool) {
|
||||||
config.sources["disableRelay"] = string(SourceCLI)
|
config.sources["disableRelay"] = string(SourceCLI)
|
||||||
}
|
}
|
||||||
|
if config.TunnelDNS != origValues["tunnelDNS"].(bool) {
|
||||||
|
config.sources["tunnelDNS"] = string(SourceCLI)
|
||||||
|
}
|
||||||
// if config.DoNotCreateNewClient != origValues["doNotCreateNewClient"].(bool) {
|
// if config.DoNotCreateNewClient != origValues["doNotCreateNewClient"].(bool) {
|
||||||
// config.sources["doNotCreateNewClient"] = string(SourceCLI)
|
// config.sources["doNotCreateNewClient"] = string(SourceCLI)
|
||||||
// }
|
// }
|
||||||
@@ -606,6 +618,7 @@ func (c *OlmConfig) ShowConfig() {
|
|||||||
fmt.Println("\nAdvanced:")
|
fmt.Println("\nAdvanced:")
|
||||||
fmt.Printf(" disable-holepunch = %v [%s]\n", c.DisableHolepunch, getSource("disableHolepunch"))
|
fmt.Printf(" disable-holepunch = %v [%s]\n", c.DisableHolepunch, getSource("disableHolepunch"))
|
||||||
fmt.Printf(" override-dns = %v [%s]\n", c.OverrideDNS, getSource("overrideDNS"))
|
fmt.Printf(" override-dns = %v [%s]\n", c.OverrideDNS, getSource("overrideDNS"))
|
||||||
|
fmt.Printf(" tunnel-dns = %v [%s]\n", c.TunnelDNS, getSource("tunnelDNS"))
|
||||||
fmt.Printf(" disable-relay = %v [%s]\n", c.DisableRelay, getSource("disableRelay"))
|
fmt.Printf(" disable-relay = %v [%s]\n", c.DisableRelay, getSource("disableRelay"))
|
||||||
// fmt.Printf(" do-not-create-new-client = %v [%s]\n", c.DoNotCreateNewClient, getSource("doNotCreateNewClient"))
|
// fmt.Printf(" do-not-create-new-client = %v [%s]\n", c.DoNotCreateNewClient, getSource("doNotCreateNewClient"))
|
||||||
if c.TlsClientCert != "" {
|
if c.TlsClientCert != "" {
|
||||||
|
|||||||
313
dns/dns_proxy.go
313
dns/dns_proxy.go
@@ -34,18 +34,26 @@ type DNSProxy struct {
|
|||||||
ep *channel.Endpoint
|
ep *channel.Endpoint
|
||||||
proxyIP netip.Addr
|
proxyIP netip.Addr
|
||||||
upstreamDNS []string
|
upstreamDNS []string
|
||||||
|
tunnelDNS bool // Whether to tunnel DNS queries over WireGuard or to spit them out locally
|
||||||
mtu int
|
mtu int
|
||||||
tunDevice tun.Device // Direct reference to underlying TUN device for responses
|
tunDevice tun.Device // Direct reference to underlying TUN device for responses
|
||||||
middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering
|
middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering
|
||||||
recordStore *DNSRecordStore // Local DNS records
|
recordStore *DNSRecordStore // Local DNS records
|
||||||
|
|
||||||
|
// Tunnel DNS fields - for sending queries over WireGuard
|
||||||
|
tunnelIP netip.Addr // WireGuard interface IP (source for tunneled queries)
|
||||||
|
tunnelStack *stack.Stack // Separate netstack for outbound tunnel queries
|
||||||
|
tunnelEp *channel.Endpoint
|
||||||
|
tunnelActivePorts map[uint16]bool
|
||||||
|
tunnelPortsLock sync.Mutex
|
||||||
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDNSProxy creates a new DNS proxy
|
// NewDNSProxy creates a new DNS proxy
|
||||||
func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string) (*DNSProxy, error) {
|
func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) {
|
||||||
proxyIP, err := PickIPFromSubnet(utilitySubnet)
|
proxyIP, err := PickIPFromSubnet(utilitySubnet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to pick DNS proxy IP from subnet: %v", err)
|
return nil, fmt.Errorf("failed to pick DNS proxy IP from subnet: %v", err)
|
||||||
@@ -58,17 +66,28 @@ func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu in
|
|||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
proxy := &DNSProxy{
|
proxy := &DNSProxy{
|
||||||
proxyIP: proxyIP,
|
proxyIP: proxyIP,
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
tunDevice: tunDevice,
|
tunDevice: tunDevice,
|
||||||
middleDevice: middleDevice,
|
middleDevice: middleDevice,
|
||||||
upstreamDNS: upstreamDns,
|
upstreamDNS: upstreamDns,
|
||||||
recordStore: NewDNSRecordStore(),
|
tunnelDNS: tunnelDns,
|
||||||
ctx: ctx,
|
recordStore: NewDNSRecordStore(),
|
||||||
cancel: cancel,
|
tunnelActivePorts: make(map[uint16]bool),
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create gvisor netstack
|
// Parse tunnel IP if provided (needed for tunneled DNS)
|
||||||
|
if tunnelIP != "" {
|
||||||
|
addr, err := netip.ParseAddr(tunnelIP)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse tunnel IP: %v", err)
|
||||||
|
}
|
||||||
|
proxy.tunnelIP = addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create gvisor netstack for receiving DNS queries
|
||||||
stackOpts := stack.Options{
|
stackOpts := stack.Options{
|
||||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
||||||
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
|
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
|
||||||
@@ -101,9 +120,104 @@ func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu in
|
|||||||
NIC: 1,
|
NIC: 1,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Initialize tunnel netstack if tunnel DNS is enabled
|
||||||
|
if tunnelDns {
|
||||||
|
if !proxy.tunnelIP.IsValid() {
|
||||||
|
return nil, fmt.Errorf("tunnel IP is required when tunnelDNS is enabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: DO WE NEED TO ESTABLISH ANOTHER NETSTACK HERE OR CAN WE COMBINE WITH WGTESTER?
|
||||||
|
if err := proxy.initTunnelNetstack(); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to initialize tunnel netstack: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return proxy, nil
|
return proxy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// initTunnelNetstack creates a separate netstack for outbound DNS queries through the tunnel
|
||||||
|
func (p *DNSProxy) initTunnelNetstack() error {
|
||||||
|
// Create gvisor netstack for outbound tunnel queries
|
||||||
|
stackOpts := stack.Options{
|
||||||
|
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
||||||
|
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
|
||||||
|
HandleLocal: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
p.tunnelEp = channel.New(256, uint32(p.mtu), "")
|
||||||
|
p.tunnelStack = stack.New(stackOpts)
|
||||||
|
|
||||||
|
// Create NIC
|
||||||
|
if err := p.tunnelStack.CreateNIC(1, p.tunnelEp); err != nil {
|
||||||
|
return fmt.Errorf("failed to create tunnel NIC: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add tunnel IP address (WireGuard interface IP)
|
||||||
|
ipBytes := p.tunnelIP.As4()
|
||||||
|
protoAddr := tcpip.ProtocolAddress{
|
||||||
|
Protocol: ipv4.ProtocolNumber,
|
||||||
|
AddressWithPrefix: tcpip.AddrFrom4(ipBytes).WithPrefix(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.tunnelStack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
|
||||||
|
return fmt.Errorf("failed to add tunnel protocol address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add default route
|
||||||
|
p.tunnelStack.AddRoute(tcpip.Route{
|
||||||
|
Destination: header.IPv4EmptySubnet,
|
||||||
|
NIC: 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Register filter rule on MiddleDevice to intercept responses
|
||||||
|
p.middleDevice.AddRule(p.tunnelIP, p.handleTunnelResponse)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleTunnelResponse handles packets coming back from the tunnel destined for the tunnel IP
|
||||||
|
func (p *DNSProxy) handleTunnelResponse(packet []byte) bool {
|
||||||
|
// Check if it's UDP
|
||||||
|
proto, ok := util.GetProtocol(packet)
|
||||||
|
if !ok || proto != 17 { // UDP
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check destination port - should be one of our active outbound ports
|
||||||
|
port, ok := util.GetDestPort(packet)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we are expecting a response on this port
|
||||||
|
p.tunnelPortsLock.Lock()
|
||||||
|
active := p.tunnelActivePorts[uint16(port)]
|
||||||
|
p.tunnelPortsLock.Unlock()
|
||||||
|
|
||||||
|
if !active {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inject into tunnel netstack
|
||||||
|
version := packet[0] >> 4
|
||||||
|
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||||
|
Payload: buffer.MakeWithData(packet),
|
||||||
|
})
|
||||||
|
|
||||||
|
switch version {
|
||||||
|
case 4:
|
||||||
|
p.tunnelEp.InjectInbound(ipv4.ProtocolNumber, pkb)
|
||||||
|
case 6:
|
||||||
|
p.tunnelEp.InjectInbound(ipv6.ProtocolNumber, pkb)
|
||||||
|
default:
|
||||||
|
pkb.DecRef()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
pkb.DecRef()
|
||||||
|
return true // Handled
|
||||||
|
}
|
||||||
|
|
||||||
// Start starts the DNS proxy and registers with the filter
|
// Start starts the DNS proxy and registers with the filter
|
||||||
func (p *DNSProxy) Start() error {
|
func (p *DNSProxy) Start() error {
|
||||||
// Install packet filter rule
|
// Install packet filter rule
|
||||||
@@ -114,7 +228,13 @@ func (p *DNSProxy) Start() error {
|
|||||||
go p.runDNSListener()
|
go p.runDNSListener()
|
||||||
go p.runPacketSender()
|
go p.runPacketSender()
|
||||||
|
|
||||||
logger.Info("DNS proxy started on %s:%d", p.proxyIP.String(), DNSPort)
|
// Start tunnel packet sender if tunnel DNS is enabled
|
||||||
|
if p.tunnelDNS {
|
||||||
|
p.wg.Add(1)
|
||||||
|
go p.runTunnelPacketSender()
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("DNS proxy started on %s:%d (tunnelDNS=%v)", p.proxyIP.String(), DNSPort, p.tunnelDNS)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -122,6 +242,9 @@ func (p *DNSProxy) Start() error {
|
|||||||
func (p *DNSProxy) Stop() {
|
func (p *DNSProxy) Stop() {
|
||||||
if p.middleDevice != nil {
|
if p.middleDevice != nil {
|
||||||
p.middleDevice.RemoveRule(p.proxyIP)
|
p.middleDevice.RemoveRule(p.proxyIP)
|
||||||
|
if p.tunnelDNS && p.tunnelIP.IsValid() {
|
||||||
|
p.middleDevice.RemoveRule(p.tunnelIP)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
p.cancel()
|
p.cancel()
|
||||||
|
|
||||||
@@ -130,12 +253,21 @@ func (p *DNSProxy) Stop() {
|
|||||||
p.ep.Close()
|
p.ep.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close tunnel endpoint if it exists
|
||||||
|
if p.tunnelEp != nil {
|
||||||
|
p.tunnelEp.Close()
|
||||||
|
}
|
||||||
|
|
||||||
p.wg.Wait()
|
p.wg.Wait()
|
||||||
|
|
||||||
if p.stack != nil {
|
if p.stack != nil {
|
||||||
p.stack.Close()
|
p.stack.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if p.tunnelStack != nil {
|
||||||
|
p.tunnelStack.Close()
|
||||||
|
}
|
||||||
|
|
||||||
logger.Info("DNS proxy stopped")
|
logger.Info("DNS proxy stopped")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -348,8 +480,16 @@ func (p *DNSProxy) forwardToUpstream(query *dns.Msg) *dns.Msg {
|
|||||||
return response
|
return response
|
||||||
}
|
}
|
||||||
|
|
||||||
// queryUpstream sends a DNS query to upstream server using miekg/dns
|
// queryUpstream sends a DNS query to upstream server
|
||||||
func (p *DNSProxy) queryUpstream(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) {
|
func (p *DNSProxy) queryUpstream(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) {
|
||||||
|
if p.tunnelDNS {
|
||||||
|
return p.queryUpstreamTunnel(server, query, timeout)
|
||||||
|
}
|
||||||
|
return p.queryUpstreamDirect(server, query, timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// queryUpstreamDirect sends a DNS query to upstream server using miekg/dns directly (host networking)
|
||||||
|
func (p *DNSProxy) queryUpstreamDirect(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) {
|
||||||
client := &dns.Client{
|
client := &dns.Client{
|
||||||
Timeout: timeout,
|
Timeout: timeout,
|
||||||
}
|
}
|
||||||
@@ -362,6 +502,155 @@ func (p *DNSProxy) queryUpstream(server string, query *dns.Msg, timeout time.Dur
|
|||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// queryUpstreamTunnel sends a DNS query through the WireGuard tunnel
|
||||||
|
func (p *DNSProxy) queryUpstreamTunnel(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) {
|
||||||
|
// Dial through the tunnel netstack
|
||||||
|
conn, port, err := p.dialTunnel("udp", server)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to dial tunnel: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
conn.Close()
|
||||||
|
p.removeTunnelPort(port)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Pack the query
|
||||||
|
queryData, err := query.Pack()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to pack query: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set deadline
|
||||||
|
conn.SetDeadline(time.Now().Add(timeout))
|
||||||
|
|
||||||
|
// Send the query
|
||||||
|
_, err = conn.Write(queryData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to send query: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the response
|
||||||
|
buf := make([]byte, 4096)
|
||||||
|
n, err := conn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the response
|
||||||
|
response := new(dns.Msg)
|
||||||
|
if err := response.Unpack(buf[:n]); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unpack response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// dialTunnel creates a UDP connection through the tunnel netstack
|
||||||
|
func (p *DNSProxy) dialTunnel(network, addr string) (net.Conn, uint16, error) {
|
||||||
|
if p.tunnelStack == nil {
|
||||||
|
return nil, 0, fmt.Errorf("tunnel netstack not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse remote address
|
||||||
|
raddr, err := net.ResolveUDPAddr("udp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use tunnel IP as source
|
||||||
|
ipBytes := p.tunnelIP.As4()
|
||||||
|
|
||||||
|
// Create UDP connection with ephemeral port
|
||||||
|
laddr := &tcpip.FullAddress{
|
||||||
|
NIC: 1,
|
||||||
|
Addr: tcpip.AddrFrom4(ipBytes),
|
||||||
|
Port: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
raddrTcpip := &tcpip.FullAddress{
|
||||||
|
NIC: 1,
|
||||||
|
Addr: tcpip.AddrFrom4([4]byte(raddr.IP.To4())),
|
||||||
|
Port: uint16(raddr.Port),
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := gonet.DialUDP(p.tunnelStack, laddr, raddrTcpip, ipv4.ProtocolNumber)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get local port
|
||||||
|
localAddr := conn.LocalAddr().(*net.UDPAddr)
|
||||||
|
port := uint16(localAddr.Port)
|
||||||
|
|
||||||
|
// Register port so we can receive responses
|
||||||
|
p.tunnelPortsLock.Lock()
|
||||||
|
p.tunnelActivePorts[port] = true
|
||||||
|
p.tunnelPortsLock.Unlock()
|
||||||
|
|
||||||
|
return conn, port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeTunnelPort removes a port from the active ports map
|
||||||
|
func (p *DNSProxy) removeTunnelPort(port uint16) {
|
||||||
|
p.tunnelPortsLock.Lock()
|
||||||
|
delete(p.tunnelActivePorts, port)
|
||||||
|
p.tunnelPortsLock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// runTunnelPacketSender reads packets from tunnel netstack and injects them into WireGuard
|
||||||
|
func (p *DNSProxy) runTunnelPacketSender() {
|
||||||
|
defer p.wg.Done()
|
||||||
|
logger.Debug("DNS tunnel packet sender goroutine started")
|
||||||
|
|
||||||
|
ticker := time.NewTicker(1 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-p.ctx.Done():
|
||||||
|
logger.Debug("DNS tunnel packet sender exiting")
|
||||||
|
// Drain any remaining packets
|
||||||
|
for {
|
||||||
|
pkt := p.tunnelEp.Read()
|
||||||
|
if pkt == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
pkt.DecRef()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
// Try to read packets
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
pkt := p.tunnelEp.Read()
|
||||||
|
if pkt == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract packet data
|
||||||
|
slices := pkt.AsSlices()
|
||||||
|
if len(slices) > 0 {
|
||||||
|
var totalSize int
|
||||||
|
for _, slice := range slices {
|
||||||
|
totalSize += len(slice)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, totalSize)
|
||||||
|
pos := 0
|
||||||
|
for _, slice := range slices {
|
||||||
|
copy(buf[pos:], slice)
|
||||||
|
pos += len(slice)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inject into MiddleDevice (outbound to WG)
|
||||||
|
p.middleDevice.InjectOutbound(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
pkt.DecRef()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// runPacketSender sends packets from netstack back to TUN
|
// runPacketSender sends packets from netstack back to TUN
|
||||||
func (p *DNSProxy) runPacketSender() {
|
func (p *DNSProxy) runPacketSender() {
|
||||||
defer p.wg.Done()
|
defer p.wg.Done()
|
||||||
|
|||||||
14
olm/olm.go
14
olm/olm.go
@@ -374,8 +374,14 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
logger.Error("Failed to bring up WireGuard device: %v", err)
|
logger.Error("Failed to bring up WireGuard device: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract interface IP (strip CIDR notation if present)
|
||||||
|
interfaceIP := wgData.TunnelIP
|
||||||
|
if strings.Contains(interfaceIP, "/") {
|
||||||
|
interfaceIP = strings.Split(interfaceIP, "/")[0]
|
||||||
|
}
|
||||||
|
|
||||||
// Create and start DNS proxy
|
// Create and start DNS proxy
|
||||||
dnsProxy, err = dns.NewDNSProxy(tdev, middleDev, config.MTU, wgData.UtilitySubnet, config.UpstreamDNS)
|
dnsProxy, err = dns.NewDNSProxy(tdev, middleDev, config.MTU, wgData.UtilitySubnet, config.UpstreamDNS, config.TunnelDNS, interfaceIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to create DNS proxy: %v", err)
|
logger.Error("Failed to create DNS proxy: %v", err)
|
||||||
}
|
}
|
||||||
@@ -388,12 +394,6 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
logger.Error("Failed to add route for utility subnet: %v", err)
|
logger.Error("Failed to add route for utility subnet: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: seperate adding the callback to this so we can init it above with the interface
|
|
||||||
interfaceIP := wgData.TunnelIP
|
|
||||||
if strings.Contains(interfaceIP, "/") {
|
|
||||||
interfaceIP = strings.Split(interfaceIP, "/")[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create peer manager with integrated peer monitoring
|
// Create peer manager with integrated peer monitoring
|
||||||
peerManager = peers.NewPeerManager(peers.PeerManagerConfig{
|
peerManager = peers.NewPeerManager(peers.PeerManagerConfig{
|
||||||
Device: dev,
|
Device: dev,
|
||||||
|
|||||||
@@ -61,6 +61,7 @@ type TunnelConfig struct {
|
|||||||
EnableUAPI bool
|
EnableUAPI bool
|
||||||
|
|
||||||
OverrideDNS bool
|
OverrideDNS bool
|
||||||
|
TunnelDNS bool
|
||||||
|
|
||||||
DisableRelay bool
|
DisableRelay bool
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user