Files
netbird/client/internal/dns/service_memory.go

165 lines
3.6 KiB
Go

package dns
import (
"errors"
"fmt"
"net/netip"
"sync"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
nbnet "github.com/netbirdio/netbird/client/net"
)
type ServiceViaMemory struct {
wgInterface WGIface
dnsMux *dns.ServeMux
runtimeIP netip.Addr
runtimePort int
tcpDNS *tcpDNSServer
tcpHookSet bool
listenerIsRunning bool
listenerFlagLock sync.Mutex
}
func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
lastIP, err := nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1)
if err != nil {
log.Errorf("get last ip from network: %v", err)
}
return &ServiceViaMemory{
wgInterface: wgIface,
dnsMux: dns.NewServeMux(),
runtimeIP: lastIP,
runtimePort: DefaultPort,
}
}
func (s *ServiceViaMemory) Listen() error {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if s.listenerIsRunning {
return nil
}
if err := s.filterDNSTraffic(); err != nil {
return fmt.Errorf("filter dns traffic: %w", err)
}
s.listenerIsRunning = true
log.Debugf("dns service listening on: %s", s.RuntimeIP())
return nil
}
func (s *ServiceViaMemory) Stop() error {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if !s.listenerIsRunning {
return nil
}
filter := s.wgInterface.GetFilter()
if filter != nil {
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
if s.tcpHookSet {
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
}
}
if s.tcpDNS != nil {
s.tcpDNS.Stop()
}
s.listenerIsRunning = false
return nil
}
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler)
}
func (s *ServiceViaMemory) DeregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}
func (s *ServiceViaMemory) RuntimePort() int {
return s.runtimePort
}
func (s *ServiceViaMemory) RuntimeIP() netip.Addr {
return s.runtimeIP
}
func (s *ServiceViaMemory) filterDNSTraffic() error {
filter := s.wgInterface.GetFilter()
if filter == nil {
return errors.New("DNS filter not initialized")
}
// Create TCP DNS server lazily here since the device may not exist at construction time.
if s.tcpDNS == nil {
if dev := s.wgInterface.GetDevice(); dev != nil {
// MTU only affects TCP segment sizing; DNS messages are small so this has no practical impact.
s.tcpDNS = newTCPDNSServer(s.dnsMux, dev.Device, s.runtimeIP, uint16(s.runtimePort), iface.DefaultMTU)
}
}
firstLayerDecoder := layers.LayerTypeIPv4
if s.wgInterface.Address().IP.Is6() {
firstLayerDecoder = layers.LayerTypeIPv6
}
hook := func(packetData []byte) bool {
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
udpLayer := packet.Layer(layers.LayerTypeUDP)
if udpLayer == nil {
return true
}
udp, ok := udpLayer.(*layers.UDP)
if !ok {
return true
}
msg := new(dns.Msg)
if err := msg.Unpack(udp.Payload); err != nil {
log.Tracef("parse DNS request: %v", err)
return true
}
dev := s.wgInterface.GetDevice()
if dev == nil {
return true
}
writer := &responseWriter{
remote: remoteAddrFromPacket(packet),
packet: packet,
device: dev.Device,
}
go s.dnsMux.ServeDNS(writer, msg)
return true
}
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), hook)
if s.tcpDNS != nil {
tcpHook := func(packetData []byte) bool {
s.tcpDNS.InjectPacket(packetData)
return true
}
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), tcpHook)
s.tcpHookSet = true
}
return nil
}