Feat fake dns address (#902)

Works only with userspace implementation:
1. Configure host to solve DNS requests via a fake DSN server address in the Netbird network.
2. Add to firewall catch rule for these DNS requests.
3. Resolve these DNS requests and respond by writing directly to wireguard device.
This commit is contained in:
Givi Khojanashvili
2023-06-08 13:46:57 +04:00
committed by GitHub
parent 2c9583dfe1
commit 1d9feab2d9
17 changed files with 721 additions and 57 deletions

View File

@@ -0,0 +1,103 @@
package dns
import (
"fmt"
"net"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/miekg/dns"
"golang.zx2c4.com/wireguard/tun"
)
type responseWriter struct {
local net.Addr
remote net.Addr
packet gopacket.Packet
device tun.Device
}
// LocalAddr returns the net.Addr of the server
func (r *responseWriter) LocalAddr() net.Addr {
return r.local
}
// RemoteAddr returns the net.Addr of the client that sent the current request.
func (r *responseWriter) RemoteAddr() net.Addr {
return r.remote
}
// WriteMsg writes a reply back to the client.
func (r *responseWriter) WriteMsg(msg *dns.Msg) error {
buff, err := msg.Pack()
if err != nil {
return err
}
_, err = r.Write(buff)
return err
}
// Write writes a raw buffer back to the client.
func (r *responseWriter) Write(data []byte) (int, error) {
var ip gopacket.SerializableLayer
// Get the UDP layer
udpLayer := r.packet.Layer(layers.LayerTypeUDP)
udp := udpLayer.(*layers.UDP)
// Swap the source and destination addresses for the response
udp.SrcPort, udp.DstPort = udp.DstPort, udp.SrcPort
// Check if it's an IPv4 packet
if ipv4Layer := r.packet.Layer(layers.LayerTypeIPv4); ipv4Layer != nil {
ipv4 := ipv4Layer.(*layers.IPv4)
ipv4.SrcIP, ipv4.DstIP = ipv4.DstIP, ipv4.SrcIP
ip = ipv4
} else if ipv6Layer := r.packet.Layer(layers.LayerTypeIPv6); ipv6Layer != nil {
ipv6 := ipv6Layer.(*layers.IPv6)
ipv6.SrcIP, ipv6.DstIP = ipv6.DstIP, ipv6.SrcIP
ip = ipv6
}
if err := udp.SetNetworkLayerForChecksum(ip.(gopacket.NetworkLayer)); err != nil {
return 0, fmt.Errorf("failed to set network layer for checksum: %v", err)
}
// Serialize the packet
buffer := gopacket.NewSerializeBuffer()
options := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
payload := gopacket.Payload(data)
err := gopacket.SerializeLayers(buffer, options, ip, udp, payload)
if err != nil {
return 0, fmt.Errorf("failed to serialize packet: %v", err)
}
send := buffer.Bytes()
sendBuffer := make([]byte, 40, len(send)+40)
sendBuffer = append(sendBuffer, send...)
return r.device.Write([][]byte{sendBuffer}, 40)
}
// Close closes the connection.
func (r *responseWriter) Close() error {
return nil
}
// TsigStatus returns the status of the Tsig.
func (r *responseWriter) TsigStatus() error {
return nil
}
// TsigTimersOnly sets the tsig timers only boolean.
func (r *responseWriter) TsigTimersOnly(bool) {
}
// Hijack lets the caller take over the connection.
// After a call to Hijack(), the DNS package will not do anything with the connection.
func (r *responseWriter) Hijack() {
}

View File

@@ -0,0 +1,93 @@
package dns
import (
"net"
"testing"
"github.com/golang/mock/gomock"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/miekg/dns"
"github.com/netbirdio/netbird/iface/mocks"
)
func TestResponseWriterLocalAddr(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
device := mocks.NewMockDevice(ctrl)
device.EXPECT().Write(gomock.Any(), gomock.Any())
request := &dns.Msg{
Question: []dns.Question{{
Name: "google.com.",
Qtype: dns.TypeA,
Qclass: dns.TypeA,
}},
}
replyMessage := &dns.Msg{}
replyMessage.SetReply(request)
replyMessage.RecursionAvailable = true
replyMessage.Rcode = dns.RcodeSuccess
replyMessage.Answer = []dns.RR{
&dns.A{
A: net.IPv4(8, 8, 8, 8),
},
}
ipv4 := &layers.IPv4{
Protocol: layers.IPProtocolUDP,
SrcIP: net.IPv4(127, 0, 0, 1),
DstIP: net.IPv4(127, 0, 0, 2),
}
udp := &layers.UDP{
DstPort: 53,
SrcPort: 45223,
}
if err := udp.SetNetworkLayerForChecksum(ipv4); err != nil {
t.Error("failed to set network layer for checksum")
return
}
// Serialize the packet
buffer := gopacket.NewSerializeBuffer()
options := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
requestData, err := request.Pack()
if err != nil {
t.Errorf("got an error while packing the request message, error: %v", err)
return
}
payload := gopacket.Payload(requestData)
if err := gopacket.SerializeLayers(buffer, options, ipv4, udp, payload); err != nil {
t.Errorf("failed to serialize packet: %v", err)
return
}
rw := &responseWriter{
local: &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 55223,
},
remote: &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 53,
},
packet: gopacket.NewPacket(
buffer.Bytes(),
layers.LayerTypeIPv4,
gopacket.Default,
),
device: device,
}
if err := rw.WriteMsg(replyMessage); err != nil {
t.Errorf("got an error while writing the local resolver response, error: %v", err)
return
}
}

View File

@@ -5,12 +5,15 @@ package dns
import (
"context"
"fmt"
"math/big"
"net"
"net/netip"
"runtime"
"sync"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/miekg/dns"
"github.com/mitchellh/hashstructure/v2"
log "github.com/sirupsen/logrus"
@@ -33,6 +36,7 @@ type DefaultServer struct {
ctx context.Context
ctxCancel context.CancelFunc
mux sync.Mutex
fakeResolverWG sync.WaitGroup
server *dns.Server
dnsMux *dns.ServeMux
dnsMuxMap registeredHandlerMap
@@ -105,6 +109,25 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
// Start runs the listener in a go routine
func (s *DefaultServer) Start() {
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1)
s.runtimePort = 53
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
s.fakeResolverWG.Add(1)
go func() {
s.setListenerStatus(true)
defer s.setListenerStatus(false)
hookID := s.filterDNSTraffic()
s.fakeResolverWG.Wait()
if err := s.wgInterface.GetFilter().RemovePacketHook(hookID); err != nil {
log.Errorf("unable to remove DNS packet hook: %s", err)
}
}()
return
}
if s.customAddress != nil {
s.runtimeIP = s.customAddress.Addr().String()
s.runtimePort = int(s.customAddress.Port())
@@ -172,6 +195,10 @@ func (s *DefaultServer) Stop() {
log.Error(err)
}
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
s.fakeResolverWG.Done()
}
err = s.stopListener()
if err != nil {
log.Error(err)
@@ -235,12 +262,15 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
}
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
// is the service should be disabled, we stop the listener
// is the service should be disabled, we stop the listener or fake resolver
// and proceed with a regular update to clean up the handlers and records
if !update.ServiceEnable {
err := s.stopListener()
if err != nil {
log.Error(err)
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
s.fakeResolverWG.Done()
} else {
if err := s.stopListener(); err != nil {
log.Error(err)
}
}
} else if !s.listenerIsRunning {
s.Start()
@@ -477,3 +507,59 @@ func (s *DefaultServer) upstreamCallbacks(
}
return
}
func (s *DefaultServer) filterDNSTraffic() string {
filter := s.wgInterface.GetFilter()
if filter == nil {
log.Error("can't set DNS filter, filter not initialized")
return ""
}
firstLayerDecoder := layers.LayerTypeIPv4
if s.wgInterface.Address().Network.IP.To4() == nil {
firstLayerDecoder = layers.LayerTypeIPv6
}
hook := func(packetData []byte) bool {
// Decode the packet
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
// Get the UDP layer
udpLayer := packet.Layer(layers.LayerTypeUDP)
udp := udpLayer.(*layers.UDP)
msg := new(dns.Msg)
if err := msg.Unpack(udp.Payload); err != nil {
log.Tracef("parse DNS request: %v", err)
return true
}
writer := responseWriter{
packet: packet,
device: s.wgInterface.GetDevice().Device,
}
go s.dnsMux.ServeDNS(&writer, msg)
return true
}
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook)
}
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
// Calculate the last IP in the CIDR range
var endIP net.IP
for i := 0; i < len(network.IP); i++ {
endIP = append(endIP, network.IP[i]|^network.Mask[i])
}
// convert to big.Int
endInt := big.NewInt(0)
endInt.SetBytes(endIP)
// subtract fromEnd from the last ip
fromEndBig := big.NewInt(int64(fromEnd))
resultInt := big.NewInt(0)
resultInt.Sub(endInt, fromEndBig)
return net.IP(resultInt.Bytes()).String()
}

View File

@@ -0,0 +1,31 @@
package dns
import (
"net"
"testing"
)
func TestGetLastIPFromNetwork(t *testing.T) {
tests := []struct {
addr string
ip string
}{
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
{"192.168.0.0/30", "192.168.0.2"},
{"192.168.0.0/16", "192.168.255.254"},
{"192.168.0.0/24", "192.168.0.254"},
}
for _, tt := range tests {
_, ipnet, err := net.ParseCIDR(tt.addr)
if err != nil {
t.Errorf("Error parsing CIDR: %v", err)
return
}
lastIP := getLastIPFromNetwork(ipnet, 1)
if lastIP != tt.ip {
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
}
}
}

View File

@@ -9,10 +9,9 @@ import (
"testing"
"time"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
)
@@ -238,6 +237,7 @@ func TestUpdateDNSServer(t *testing.T) {
dnsServer.updateSerial = testCase.initSerial
// pretend we are running
dnsServer.listenerIsRunning = true
dnsServer.fakeResolverWG.Add(1)
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
if err != nil {