mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
165 lines
3.6 KiB
Go
165 lines
3.6 KiB
Go
package dns
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"net/netip"
|
|
"sync"
|
|
|
|
"github.com/google/gopacket"
|
|
"github.com/google/gopacket/layers"
|
|
"github.com/miekg/dns"
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
"github.com/netbirdio/netbird/client/iface"
|
|
nbnet "github.com/netbirdio/netbird/client/net"
|
|
)
|
|
|
|
type ServiceViaMemory struct {
|
|
wgInterface WGIface
|
|
dnsMux *dns.ServeMux
|
|
runtimeIP netip.Addr
|
|
runtimePort int
|
|
tcpDNS *tcpDNSServer
|
|
tcpHookSet bool
|
|
listenerIsRunning bool
|
|
listenerFlagLock sync.Mutex
|
|
}
|
|
|
|
func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
|
|
lastIP, err := nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1)
|
|
if err != nil {
|
|
log.Errorf("get last ip from network: %v", err)
|
|
}
|
|
|
|
return &ServiceViaMemory{
|
|
wgInterface: wgIface,
|
|
dnsMux: dns.NewServeMux(),
|
|
runtimeIP: lastIP,
|
|
runtimePort: DefaultPort,
|
|
}
|
|
}
|
|
|
|
func (s *ServiceViaMemory) Listen() error {
|
|
s.listenerFlagLock.Lock()
|
|
defer s.listenerFlagLock.Unlock()
|
|
|
|
if s.listenerIsRunning {
|
|
return nil
|
|
}
|
|
|
|
if err := s.filterDNSTraffic(); err != nil {
|
|
return fmt.Errorf("filter dns traffic: %w", err)
|
|
}
|
|
s.listenerIsRunning = true
|
|
|
|
log.Debugf("dns service listening on: %s", s.RuntimeIP())
|
|
return nil
|
|
}
|
|
|
|
func (s *ServiceViaMemory) Stop() error {
|
|
s.listenerFlagLock.Lock()
|
|
defer s.listenerFlagLock.Unlock()
|
|
|
|
if !s.listenerIsRunning {
|
|
return nil
|
|
}
|
|
|
|
filter := s.wgInterface.GetFilter()
|
|
if filter != nil {
|
|
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
|
|
if s.tcpHookSet {
|
|
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
|
|
}
|
|
}
|
|
|
|
if s.tcpDNS != nil {
|
|
s.tcpDNS.Stop()
|
|
}
|
|
|
|
s.listenerIsRunning = false
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
|
s.dnsMux.Handle(pattern, handler)
|
|
}
|
|
|
|
func (s *ServiceViaMemory) DeregisterMux(pattern string) {
|
|
s.dnsMux.HandleRemove(pattern)
|
|
}
|
|
|
|
func (s *ServiceViaMemory) RuntimePort() int {
|
|
return s.runtimePort
|
|
}
|
|
|
|
func (s *ServiceViaMemory) RuntimeIP() netip.Addr {
|
|
return s.runtimeIP
|
|
}
|
|
|
|
func (s *ServiceViaMemory) filterDNSTraffic() error {
|
|
filter := s.wgInterface.GetFilter()
|
|
if filter == nil {
|
|
return errors.New("DNS filter not initialized")
|
|
}
|
|
|
|
// Create TCP DNS server lazily here since the device may not exist at construction time.
|
|
if s.tcpDNS == nil {
|
|
if dev := s.wgInterface.GetDevice(); dev != nil {
|
|
// MTU only affects TCP segment sizing; DNS messages are small so this has no practical impact.
|
|
s.tcpDNS = newTCPDNSServer(s.dnsMux, dev.Device, s.runtimeIP, uint16(s.runtimePort), iface.DefaultMTU)
|
|
}
|
|
}
|
|
|
|
firstLayerDecoder := layers.LayerTypeIPv4
|
|
if s.wgInterface.Address().IP.Is6() {
|
|
firstLayerDecoder = layers.LayerTypeIPv6
|
|
}
|
|
|
|
hook := func(packetData []byte) bool {
|
|
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
|
|
|
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
|
if udpLayer == nil {
|
|
return true
|
|
}
|
|
udp, ok := udpLayer.(*layers.UDP)
|
|
if !ok {
|
|
return true
|
|
}
|
|
|
|
msg := new(dns.Msg)
|
|
if err := msg.Unpack(udp.Payload); err != nil {
|
|
log.Tracef("parse DNS request: %v", err)
|
|
return true
|
|
}
|
|
|
|
dev := s.wgInterface.GetDevice()
|
|
if dev == nil {
|
|
return true
|
|
}
|
|
|
|
writer := &responseWriter{
|
|
remote: remoteAddrFromPacket(packet),
|
|
packet: packet,
|
|
device: dev.Device,
|
|
}
|
|
go s.dnsMux.ServeDNS(writer, msg)
|
|
return true
|
|
}
|
|
|
|
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), hook)
|
|
|
|
if s.tcpDNS != nil {
|
|
tcpHook := func(packetData []byte) bool {
|
|
s.tcpDNS.InjectPacket(packetData)
|
|
return true
|
|
}
|
|
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), tcpHook)
|
|
s.tcpHookSet = true
|
|
}
|
|
|
|
return nil
|
|
}
|