mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
445 lines
11 KiB
Go
445 lines
11 KiB
Go
package dns
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/netip"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
log "github.com/sirupsen/logrus"
|
|
"golang.zx2c4.com/wireguard/tun"
|
|
"gvisor.dev/gvisor/pkg/buffer"
|
|
"gvisor.dev/gvisor/pkg/tcpip"
|
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
|
"gvisor.dev/gvisor/pkg/waiter"
|
|
)
|
|
|
|
// Tunables for the on-demand TCP DNS stack.
const (
	// dnsTCPIdleTimeout is how long the stack may go without activity before
	// the idle timer tears it down.
	dnsTCPIdleTimeout = 30 * time.Second
	// dnsTCPReadTimeout is the read deadline applied before each DNS message
	// is read from an accepted connection.
	dnsTCPReadTimeout = 5 * time.Second

	// dnsTCPReceiveWindow is the receive window handed to the TCP forwarder.
	dnsTCPReceiveWindow = 8192
	// dnsTCPMaxInFlight is the max-in-flight value handed to the TCP forwarder.
	dnsTCPMaxInFlight = 16
)
|
|
|
|
// tcpDNSServer is an on-demand TCP DNS server backed by a minimal gvisor stack.
|
|
// It is started lazily when a truncated DNS response is detected and shuts down
|
|
// after a period of inactivity to conserve resources.
|
|
type tcpDNSServer struct {
|
|
mu sync.Mutex
|
|
s *stack.Stack
|
|
ep *dnsEndpoint
|
|
mux *dns.ServeMux
|
|
tunDev tun.Device
|
|
ip netip.Addr
|
|
port uint16
|
|
mtu uint16
|
|
|
|
running bool
|
|
closed bool
|
|
timerID uint64
|
|
timer *time.Timer
|
|
}
|
|
|
|
func newTCPDNSServer(mux *dns.ServeMux, tunDev tun.Device, ip netip.Addr, port uint16, mtu uint16) *tcpDNSServer {
|
|
return &tcpDNSServer{
|
|
mux: mux,
|
|
tunDev: tunDev,
|
|
ip: ip,
|
|
port: port,
|
|
mtu: mtu,
|
|
}
|
|
}
|
|
|
|
// InjectPacket ensures the stack is running and delivers a raw IP packet into
|
|
// the gvisor stack for TCP processing. Combining both operations under a single
|
|
// lock prevents a race where the idle timer could stop the stack between
|
|
// start and delivery.
|
|
func (t *tcpDNSServer) InjectPacket(payload []byte) {
|
|
t.mu.Lock()
|
|
defer t.mu.Unlock()
|
|
|
|
if t.closed {
|
|
return
|
|
}
|
|
|
|
if !t.running {
|
|
if err := t.startLocked(); err != nil {
|
|
log.Errorf("failed to start TCP DNS stack: %v", err)
|
|
return
|
|
}
|
|
t.running = true
|
|
log.Debugf("TCP DNS stack started on %s:%d (triggered by %s)", t.ip, t.port, srcAddrFromPacket(payload))
|
|
}
|
|
t.resetTimerLocked()
|
|
|
|
ep := t.ep
|
|
if ep == nil || ep.dispatcher == nil {
|
|
return
|
|
}
|
|
|
|
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
|
Payload: buffer.MakeWithData(payload),
|
|
})
|
|
// DeliverNetworkPacket takes ownership of the packet buffer; do not DecRef.
|
|
ep.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
|
|
}
|
|
|
|
// Stop tears down the gvisor stack and releases resources permanently.
|
|
// After Stop, InjectPacket becomes a no-op.
|
|
func (t *tcpDNSServer) Stop() {
|
|
t.mu.Lock()
|
|
defer t.mu.Unlock()
|
|
|
|
t.stopLocked()
|
|
t.closed = true
|
|
}
|
|
|
|
func (t *tcpDNSServer) startLocked() error {
|
|
// TODO: add ipv6.NewProtocol when IPv6 overlay support lands.
|
|
s := stack.New(stack.Options{
|
|
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
|
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
|
|
HandleLocal: false,
|
|
})
|
|
|
|
nicID := tcpip.NICID(1)
|
|
ep := &dnsEndpoint{
|
|
tunDev: t.tunDev,
|
|
}
|
|
ep.mtu.Store(uint32(t.mtu))
|
|
|
|
if err := s.CreateNIC(nicID, ep); err != nil {
|
|
s.Close()
|
|
s.Wait()
|
|
return fmt.Errorf("create NIC: %v", err)
|
|
}
|
|
|
|
protoAddr := tcpip.ProtocolAddress{
|
|
Protocol: ipv4.ProtocolNumber,
|
|
AddressWithPrefix: tcpip.AddressWithPrefix{
|
|
Address: tcpip.AddrFromSlice(t.ip.AsSlice()),
|
|
PrefixLen: 32,
|
|
},
|
|
}
|
|
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
|
|
s.Close()
|
|
s.Wait()
|
|
return fmt.Errorf("add protocol address: %s", err)
|
|
}
|
|
|
|
if err := s.SetPromiscuousMode(nicID, true); err != nil {
|
|
s.Close()
|
|
s.Wait()
|
|
return fmt.Errorf("set promiscuous mode: %s", err)
|
|
}
|
|
if err := s.SetSpoofing(nicID, true); err != nil {
|
|
s.Close()
|
|
s.Wait()
|
|
return fmt.Errorf("set spoofing: %s", err)
|
|
}
|
|
|
|
defaultSubnet, err := tcpip.NewSubnet(
|
|
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
|
|
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
|
|
)
|
|
if err != nil {
|
|
s.Close()
|
|
s.Wait()
|
|
return fmt.Errorf("create default subnet: %w", err)
|
|
}
|
|
|
|
s.SetRouteTable([]tcpip.Route{
|
|
{Destination: defaultSubnet, NIC: nicID},
|
|
})
|
|
|
|
tcpFwd := tcp.NewForwarder(s, dnsTCPReceiveWindow, dnsTCPMaxInFlight, func(r *tcp.ForwarderRequest) {
|
|
t.handleTCPDNS(r)
|
|
})
|
|
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
|
|
|
|
t.s = s
|
|
t.ep = ep
|
|
return nil
|
|
}
|
|
|
|
func (t *tcpDNSServer) stopLocked() {
|
|
if !t.running {
|
|
return
|
|
}
|
|
|
|
if t.timer != nil {
|
|
t.timer.Stop()
|
|
t.timer = nil
|
|
}
|
|
|
|
if t.s != nil {
|
|
t.s.Close()
|
|
t.s.Wait()
|
|
t.s = nil
|
|
}
|
|
t.ep = nil
|
|
t.running = false
|
|
|
|
log.Debugf("TCP DNS stack stopped")
|
|
}
|
|
|
|
func (t *tcpDNSServer) resetTimerLocked() {
|
|
if t.timer != nil {
|
|
t.timer.Stop()
|
|
}
|
|
t.timerID++
|
|
id := t.timerID
|
|
t.timer = time.AfterFunc(dnsTCPIdleTimeout, func() {
|
|
t.mu.Lock()
|
|
defer t.mu.Unlock()
|
|
|
|
// Only stop if this timer is still the active one.
|
|
// A racing InjectPacket may have replaced it.
|
|
if t.timerID != id {
|
|
return
|
|
}
|
|
t.stopLocked()
|
|
})
|
|
}
|
|
|
|
func (t *tcpDNSServer) handleTCPDNS(r *tcp.ForwarderRequest) {
|
|
id := r.ID()
|
|
|
|
wq := waiter.Queue{}
|
|
ep, epErr := r.CreateEndpoint(&wq)
|
|
if epErr != nil {
|
|
log.Debugf("TCP DNS: failed to create endpoint: %v", epErr)
|
|
r.Complete(true)
|
|
return
|
|
}
|
|
r.Complete(false)
|
|
|
|
conn := gonet.NewTCPConn(&wq, ep)
|
|
defer func() {
|
|
if err := conn.Close(); err != nil {
|
|
log.Tracef("TCP DNS: close conn: %v", err)
|
|
}
|
|
}()
|
|
|
|
// Reset idle timer on activity
|
|
t.mu.Lock()
|
|
t.resetTimerLocked()
|
|
t.mu.Unlock()
|
|
|
|
localAddr := &net.TCPAddr{
|
|
IP: id.LocalAddress.AsSlice(),
|
|
Port: int(id.LocalPort),
|
|
}
|
|
remoteAddr := &net.TCPAddr{
|
|
IP: id.RemoteAddress.AsSlice(),
|
|
Port: int(id.RemotePort),
|
|
}
|
|
|
|
for {
|
|
if err := conn.SetReadDeadline(time.Now().Add(dnsTCPReadTimeout)); err != nil {
|
|
log.Debugf("TCP DNS: set deadline for %s: %v", remoteAddr, err)
|
|
break
|
|
}
|
|
|
|
msg, err := readTCPDNSMessage(conn)
|
|
if err != nil {
|
|
if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
|
log.Debugf("TCP DNS: read from %s: %v", remoteAddr, err)
|
|
}
|
|
break
|
|
}
|
|
|
|
writer := &tcpResponseWriter{
|
|
conn: conn,
|
|
localAddr: localAddr,
|
|
remoteAddr: remoteAddr,
|
|
}
|
|
t.mux.ServeDNS(writer, msg)
|
|
}
|
|
}
|
|
|
|
// dnsEndpoint implements stack.LinkEndpoint for writing packets back via the tun device.
|
|
type dnsEndpoint struct {
|
|
dispatcher stack.NetworkDispatcher
|
|
tunDev tun.Device
|
|
mtu atomic.Uint32
|
|
}
|
|
|
|
func (e *dnsEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher }
|
|
func (e *dnsEndpoint) IsAttached() bool { return e.dispatcher != nil }
|
|
func (e *dnsEndpoint) MTU() uint32 { return e.mtu.Load() }
|
|
func (e *dnsEndpoint) Capabilities() stack.LinkEndpointCapabilities { return stack.CapabilityNone }
|
|
func (e *dnsEndpoint) MaxHeaderLength() uint16 { return 0 }
|
|
func (e *dnsEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
|
|
func (e *dnsEndpoint) Wait() { /* no async work */ }
|
|
func (e *dnsEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone }
|
|
func (e *dnsEndpoint) AddHeader(*stack.PacketBuffer) { /* IP-level endpoint, no link header */ }
|
|
func (e *dnsEndpoint) ParseHeader(*stack.PacketBuffer) bool { return true }
|
|
func (e *dnsEndpoint) Close() { /* lifecycle managed by tcpDNSServer */ }
|
|
func (e *dnsEndpoint) SetLinkAddress(tcpip.LinkAddress) { /* no link address for tun */ }
|
|
func (e *dnsEndpoint) SetMTU(mtu uint32) { e.mtu.Store(mtu) }
|
|
func (e *dnsEndpoint) SetOnCloseAction(func()) { /* not needed */ }
|
|
|
|
const tunPacketOffset = 40
|
|
|
|
func (e *dnsEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
|
|
var written int
|
|
for _, pkt := range pkts.AsSlice() {
|
|
data := stack.PayloadSince(pkt.NetworkHeader())
|
|
if data == nil {
|
|
continue
|
|
}
|
|
|
|
raw := data.AsSlice()
|
|
buf := make([]byte, tunPacketOffset, tunPacketOffset+len(raw))
|
|
buf = append(buf, raw...)
|
|
data.Release()
|
|
|
|
if _, err := e.tunDev.Write([][]byte{buf}, tunPacketOffset); err != nil {
|
|
log.Tracef("TCP DNS endpoint: failed to write packet: %v", err)
|
|
continue
|
|
}
|
|
written++
|
|
}
|
|
return written, nil
|
|
}
|
|
|
|
// tcpResponseWriter implements dns.ResponseWriter for TCP DNS connections.
|
|
type tcpResponseWriter struct {
|
|
conn *gonet.TCPConn
|
|
localAddr net.Addr
|
|
remoteAddr net.Addr
|
|
}
|
|
|
|
func (w *tcpResponseWriter) LocalAddr() net.Addr {
|
|
return w.localAddr
|
|
}
|
|
|
|
func (w *tcpResponseWriter) RemoteAddr() net.Addr {
|
|
return w.remoteAddr
|
|
}
|
|
|
|
func (w *tcpResponseWriter) WriteMsg(msg *dns.Msg) error {
|
|
data, err := msg.Pack()
|
|
if err != nil {
|
|
return fmt.Errorf("pack: %w", err)
|
|
}
|
|
|
|
// DNS TCP: 2-byte length prefix + message
|
|
buf := make([]byte, 2+len(data))
|
|
buf[0] = byte(len(data) >> 8)
|
|
buf[1] = byte(len(data))
|
|
copy(buf[2:], data)
|
|
|
|
if _, err = w.conn.Write(buf); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (w *tcpResponseWriter) Write(data []byte) (int, error) {
|
|
buf := make([]byte, 2+len(data))
|
|
buf[0] = byte(len(data) >> 8)
|
|
buf[1] = byte(len(data))
|
|
copy(buf[2:], data)
|
|
if _, err := w.conn.Write(buf); err != nil {
|
|
return 0, err
|
|
}
|
|
return len(data), nil
|
|
}
|
|
|
|
func (w *tcpResponseWriter) Close() error {
|
|
return w.conn.Close()
|
|
}
|
|
|
|
func (w *tcpResponseWriter) TsigStatus() error { return nil }
|
|
func (w *tcpResponseWriter) TsigTimersOnly(bool) { /* TSIG not supported */ }
|
|
func (w *tcpResponseWriter) Hijack() { /* not supported */ }
|
|
|
|
// readTCPDNSMessage reads a single DNS message from a TCP connection (length-prefixed).
|
|
func readTCPDNSMessage(conn *gonet.TCPConn) (*dns.Msg, error) {
|
|
// DNS over TCP uses a 2-byte length prefix
|
|
lenBuf := make([]byte, 2)
|
|
if _, err := io.ReadFull(conn, lenBuf); err != nil {
|
|
return nil, fmt.Errorf("read length: %w", err)
|
|
}
|
|
|
|
msgLen := int(lenBuf[0])<<8 | int(lenBuf[1])
|
|
if msgLen == 0 || msgLen > 65535 {
|
|
return nil, fmt.Errorf("invalid message length: %d", msgLen)
|
|
}
|
|
|
|
msgBuf := make([]byte, msgLen)
|
|
if _, err := io.ReadFull(conn, msgBuf); err != nil {
|
|
return nil, fmt.Errorf("read message: %w", err)
|
|
}
|
|
|
|
msg := new(dns.Msg)
|
|
if err := msg.Unpack(msgBuf); err != nil {
|
|
return nil, fmt.Errorf("unpack: %w", err)
|
|
}
|
|
return msg, nil
|
|
}
|
|
|
|
// srcAddrFromPacket extracts the source IP:port from a raw IP+TCP packet for logging.
|
|
// Supports both IPv4 and IPv6.
|
|
func srcAddrFromPacket(pkt []byte) netip.AddrPort {
|
|
if len(pkt) == 0 {
|
|
return netip.AddrPort{}
|
|
}
|
|
|
|
srcIP, transportOffset := srcIPFromPacket(pkt)
|
|
if !srcIP.IsValid() || len(pkt) < transportOffset+2 {
|
|
return netip.AddrPort{}
|
|
}
|
|
|
|
srcPort := uint16(pkt[transportOffset])<<8 | uint16(pkt[transportOffset+1])
|
|
return netip.AddrPortFrom(srcIP.Unmap(), srcPort)
|
|
}
|
|
|
|
func srcIPFromPacket(pkt []byte) (netip.Addr, int) {
|
|
switch header.IPVersion(pkt) {
|
|
case 4:
|
|
return srcIPv4(pkt)
|
|
case 6:
|
|
return srcIPv6(pkt)
|
|
default:
|
|
return netip.Addr{}, 0
|
|
}
|
|
}
|
|
|
|
func srcIPv4(pkt []byte) (netip.Addr, int) {
|
|
if len(pkt) < header.IPv4MinimumSize {
|
|
return netip.Addr{}, 0
|
|
}
|
|
hdr := header.IPv4(pkt)
|
|
src := hdr.SourceAddress()
|
|
ip, ok := netip.AddrFromSlice(src.AsSlice())
|
|
if !ok {
|
|
return netip.Addr{}, 0
|
|
}
|
|
return ip, int(hdr.HeaderLength())
|
|
}
|
|
|
|
func srcIPv6(pkt []byte) (netip.Addr, int) {
|
|
if len(pkt) < header.IPv6MinimumSize {
|
|
return netip.Addr{}, 0
|
|
}
|
|
hdr := header.IPv6(pkt)
|
|
src := hdr.SourceAddress()
|
|
ip, ok := netip.AddrFromSlice(src.AsSlice())
|
|
if !ok {
|
|
return netip.Addr{}, 0
|
|
}
|
|
return ip, header.IPv6MinimumSize
|
|
}
|