mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 16:26:38 +00:00
Feat fake dns address (#902)
Works only with userspace implementation: 1. Configure host to solve DNS requests via a fake DSN server address in the Netbird network. 2. Add to firewall catch rule for these DNS requests. 3. Resolve these DNS requests and respond by writing directly to wireguard device.
This commit is contained in:
committed by
GitHub
parent
2c9583dfe1
commit
1d9feab2d9
103
client/internal/dns/response_writer.go
Normal file
103
client/internal/dns/response_writer.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/miekg/dns"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
type responseWriter struct {
|
||||
local net.Addr
|
||||
remote net.Addr
|
||||
packet gopacket.Packet
|
||||
device tun.Device
|
||||
}
|
||||
|
||||
// LocalAddr returns the net.Addr of the server
|
||||
func (r *responseWriter) LocalAddr() net.Addr {
|
||||
return r.local
|
||||
}
|
||||
|
||||
// RemoteAddr returns the net.Addr of the client that sent the current request.
|
||||
func (r *responseWriter) RemoteAddr() net.Addr {
|
||||
return r.remote
|
||||
}
|
||||
|
||||
// WriteMsg writes a reply back to the client.
|
||||
func (r *responseWriter) WriteMsg(msg *dns.Msg) error {
|
||||
buff, err := msg.Pack()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = r.Write(buff)
|
||||
return err
|
||||
}
|
||||
|
||||
// Write writes a raw buffer back to the client.
|
||||
func (r *responseWriter) Write(data []byte) (int, error) {
|
||||
var ip gopacket.SerializableLayer
|
||||
|
||||
// Get the UDP layer
|
||||
udpLayer := r.packet.Layer(layers.LayerTypeUDP)
|
||||
udp := udpLayer.(*layers.UDP)
|
||||
|
||||
// Swap the source and destination addresses for the response
|
||||
udp.SrcPort, udp.DstPort = udp.DstPort, udp.SrcPort
|
||||
|
||||
// Check if it's an IPv4 packet
|
||||
if ipv4Layer := r.packet.Layer(layers.LayerTypeIPv4); ipv4Layer != nil {
|
||||
ipv4 := ipv4Layer.(*layers.IPv4)
|
||||
ipv4.SrcIP, ipv4.DstIP = ipv4.DstIP, ipv4.SrcIP
|
||||
ip = ipv4
|
||||
} else if ipv6Layer := r.packet.Layer(layers.LayerTypeIPv6); ipv6Layer != nil {
|
||||
ipv6 := ipv6Layer.(*layers.IPv6)
|
||||
ipv6.SrcIP, ipv6.DstIP = ipv6.DstIP, ipv6.SrcIP
|
||||
ip = ipv6
|
||||
}
|
||||
|
||||
if err := udp.SetNetworkLayerForChecksum(ip.(gopacket.NetworkLayer)); err != nil {
|
||||
return 0, fmt.Errorf("failed to set network layer for checksum: %v", err)
|
||||
}
|
||||
|
||||
// Serialize the packet
|
||||
buffer := gopacket.NewSerializeBuffer()
|
||||
options := gopacket.SerializeOptions{
|
||||
ComputeChecksums: true,
|
||||
FixLengths: true,
|
||||
}
|
||||
|
||||
payload := gopacket.Payload(data)
|
||||
err := gopacket.SerializeLayers(buffer, options, ip, udp, payload)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to serialize packet: %v", err)
|
||||
}
|
||||
|
||||
send := buffer.Bytes()
|
||||
sendBuffer := make([]byte, 40, len(send)+40)
|
||||
sendBuffer = append(sendBuffer, send...)
|
||||
|
||||
return r.device.Write([][]byte{sendBuffer}, 40)
|
||||
}
|
||||
|
||||
// Close closes the connection.
|
||||
func (r *responseWriter) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TsigStatus returns the status of the Tsig.
|
||||
func (r *responseWriter) TsigStatus() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TsigTimersOnly sets the tsig timers only boolean.
|
||||
func (r *responseWriter) TsigTimersOnly(bool) {
|
||||
}
|
||||
|
||||
// Hijack lets the caller take over the connection.
|
||||
// After a call to Hijack(), the DNS package will not do anything with the connection.
|
||||
func (r *responseWriter) Hijack() {
|
||||
}
|
||||
93
client/internal/dns/response_writer_test.go
Normal file
93
client/internal/dns/response_writer_test.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/netbirdio/netbird/iface/mocks"
|
||||
)
|
||||
|
||||
func TestResponseWriterLocalAddr(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
device := mocks.NewMockDevice(ctrl)
|
||||
device.EXPECT().Write(gomock.Any(), gomock.Any())
|
||||
|
||||
request := &dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
Name: "google.com.",
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.TypeA,
|
||||
}},
|
||||
}
|
||||
|
||||
replyMessage := &dns.Msg{}
|
||||
replyMessage.SetReply(request)
|
||||
replyMessage.RecursionAvailable = true
|
||||
replyMessage.Rcode = dns.RcodeSuccess
|
||||
replyMessage.Answer = []dns.RR{
|
||||
&dns.A{
|
||||
A: net.IPv4(8, 8, 8, 8),
|
||||
},
|
||||
}
|
||||
|
||||
ipv4 := &layers.IPv4{
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
SrcIP: net.IPv4(127, 0, 0, 1),
|
||||
DstIP: net.IPv4(127, 0, 0, 2),
|
||||
}
|
||||
udp := &layers.UDP{
|
||||
DstPort: 53,
|
||||
SrcPort: 45223,
|
||||
}
|
||||
if err := udp.SetNetworkLayerForChecksum(ipv4); err != nil {
|
||||
t.Error("failed to set network layer for checksum")
|
||||
return
|
||||
}
|
||||
|
||||
// Serialize the packet
|
||||
buffer := gopacket.NewSerializeBuffer()
|
||||
options := gopacket.SerializeOptions{
|
||||
ComputeChecksums: true,
|
||||
FixLengths: true,
|
||||
}
|
||||
|
||||
requestData, err := request.Pack()
|
||||
if err != nil {
|
||||
t.Errorf("got an error while packing the request message, error: %v", err)
|
||||
return
|
||||
}
|
||||
payload := gopacket.Payload(requestData)
|
||||
|
||||
if err := gopacket.SerializeLayers(buffer, options, ipv4, udp, payload); err != nil {
|
||||
t.Errorf("failed to serialize packet: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
rw := &responseWriter{
|
||||
local: &net.UDPAddr{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Port: 55223,
|
||||
},
|
||||
remote: &net.UDPAddr{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Port: 53,
|
||||
},
|
||||
packet: gopacket.NewPacket(
|
||||
buffer.Bytes(),
|
||||
layers.LayerTypeIPv4,
|
||||
gopacket.Default,
|
||||
),
|
||||
device: device,
|
||||
}
|
||||
if err := rw.WriteMsg(replyMessage); err != nil {
|
||||
t.Errorf("got an error while writing the local resolver response, error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -5,12 +5,15 @@ package dns
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/mitchellh/hashstructure/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -33,6 +36,7 @@ type DefaultServer struct {
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
mux sync.Mutex
|
||||
fakeResolverWG sync.WaitGroup
|
||||
server *dns.Server
|
||||
dnsMux *dns.ServeMux
|
||||
dnsMuxMap registeredHandlerMap
|
||||
@@ -105,6 +109,25 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
|
||||
|
||||
// Start runs the listener in a go routine
|
||||
func (s *DefaultServer) Start() {
|
||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
|
||||
s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1)
|
||||
s.runtimePort = 53
|
||||
|
||||
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
||||
s.fakeResolverWG.Add(1)
|
||||
go func() {
|
||||
s.setListenerStatus(true)
|
||||
defer s.setListenerStatus(false)
|
||||
|
||||
hookID := s.filterDNSTraffic()
|
||||
s.fakeResolverWG.Wait()
|
||||
if err := s.wgInterface.GetFilter().RemovePacketHook(hookID); err != nil {
|
||||
log.Errorf("unable to remove DNS packet hook: %s", err)
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
||||
if s.customAddress != nil {
|
||||
s.runtimeIP = s.customAddress.Addr().String()
|
||||
s.runtimePort = int(s.customAddress.Port())
|
||||
@@ -172,6 +195,10 @@ func (s *DefaultServer) Stop() {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
|
||||
s.fakeResolverWG.Done()
|
||||
}
|
||||
|
||||
err = s.stopListener()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
@@ -235,12 +262,15 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
|
||||
}
|
||||
|
||||
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
// is the service should be disabled, we stop the listener
|
||||
// is the service should be disabled, we stop the listener or fake resolver
|
||||
// and proceed with a regular update to clean up the handlers and records
|
||||
if !update.ServiceEnable {
|
||||
err := s.stopListener()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
|
||||
s.fakeResolverWG.Done()
|
||||
} else {
|
||||
if err := s.stopListener(); err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
}
|
||||
} else if !s.listenerIsRunning {
|
||||
s.Start()
|
||||
@@ -477,3 +507,59 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *DefaultServer) filterDNSTraffic() string {
|
||||
filter := s.wgInterface.GetFilter()
|
||||
if filter == nil {
|
||||
log.Error("can't set DNS filter, filter not initialized")
|
||||
return ""
|
||||
}
|
||||
|
||||
firstLayerDecoder := layers.LayerTypeIPv4
|
||||
if s.wgInterface.Address().Network.IP.To4() == nil {
|
||||
firstLayerDecoder = layers.LayerTypeIPv6
|
||||
}
|
||||
|
||||
hook := func(packetData []byte) bool {
|
||||
// Decode the packet
|
||||
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
||||
|
||||
// Get the UDP layer
|
||||
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
||||
udp := udpLayer.(*layers.UDP)
|
||||
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(udp.Payload); err != nil {
|
||||
log.Tracef("parse DNS request: %v", err)
|
||||
return true
|
||||
}
|
||||
|
||||
writer := responseWriter{
|
||||
packet: packet,
|
||||
device: s.wgInterface.GetDevice().Device,
|
||||
}
|
||||
go s.dnsMux.ServeDNS(&writer, msg)
|
||||
return true
|
||||
}
|
||||
|
||||
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook)
|
||||
}
|
||||
|
||||
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
|
||||
// Calculate the last IP in the CIDR range
|
||||
var endIP net.IP
|
||||
for i := 0; i < len(network.IP); i++ {
|
||||
endIP = append(endIP, network.IP[i]|^network.Mask[i])
|
||||
}
|
||||
|
||||
// convert to big.Int
|
||||
endInt := big.NewInt(0)
|
||||
endInt.SetBytes(endIP)
|
||||
|
||||
// subtract fromEnd from the last ip
|
||||
fromEndBig := big.NewInt(int64(fromEnd))
|
||||
resultInt := big.NewInt(0)
|
||||
resultInt.Sub(endInt, fromEndBig)
|
||||
|
||||
return net.IP(resultInt.Bytes()).String()
|
||||
}
|
||||
|
||||
31
client/internal/dns/server_nonandroid_test.go
Normal file
31
client/internal/dns/server_nonandroid_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetLastIPFromNetwork(t *testing.T) {
|
||||
tests := []struct {
|
||||
addr string
|
||||
ip string
|
||||
}{
|
||||
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
|
||||
{"192.168.0.0/30", "192.168.0.2"},
|
||||
{"192.168.0.0/16", "192.168.255.254"},
|
||||
{"192.168.0.0/24", "192.168.0.254"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
_, ipnet, err := net.ParseCIDR(tt.addr)
|
||||
if err != nil {
|
||||
t.Errorf("Error parsing CIDR: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
lastIP := getLastIPFromNetwork(ipnet, 1)
|
||||
if lastIP != tt.ip {
|
||||
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -9,10 +9,9 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
@@ -238,6 +237,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
dnsServer.updateSerial = testCase.initSerial
|
||||
// pretend we are running
|
||||
dnsServer.listenerIsRunning = true
|
||||
dnsServer.fakeResolverWG.Add(1)
|
||||
|
||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user