From 1d9feab2d9c160667d6df0d84144e59904996ae6 Mon Sep 17 00:00:00 2001 From: Givi Khojanashvili Date: Thu, 8 Jun 2023 13:46:57 +0400 Subject: [PATCH] Feat fake dns address (#902) Works only with userspace implementation: 1. Configure host to solve DNS requests via a fake DSN server address in the Netbird network. 2. Add to firewall catch rule for these DNS requests. 3. Resolve these DNS requests and respond by writing directly to wireguard device. --- client/firewall/uspfilter/rule.go | 2 + client/firewall/uspfilter/uspfilter.go | 65 ++++++- client/firewall/uspfilter/uspfilter_test.go | 160 ++++++++++++++++-- client/internal/acl/manager.go | 2 +- client/internal/acl/manager_test.go | 4 +- client/internal/acl/mocks/iface_mapper.go | 12 +- client/internal/dns/response_writer.go | 103 +++++++++++ client/internal/dns/response_writer_test.go | 93 ++++++++++ client/internal/dns/server_nonandroid.go | 94 +++++++++- client/internal/dns/server_nonandroid_test.go | 31 ++++ client/internal/dns/server_test.go | 4 +- iface/device_wrapper.go | 13 +- iface/device_wrapper_test.go | 26 +-- iface/iface.go | 27 ++- iface/mocks/README.md | 7 + iface/mocks/filter.go | 48 ++++-- iface/mocks/iface/mocks/filter.go | 87 ++++++++++ 17 files changed, 721 insertions(+), 57 deletions(-) create mode 100644 client/internal/dns/response_writer.go create mode 100644 client/internal/dns/response_writer_test.go create mode 100644 client/internal/dns/server_nonandroid_test.go create mode 100644 iface/mocks/README.md create mode 100644 iface/mocks/iface/mocks/filter.go diff --git a/client/firewall/uspfilter/rule.go b/client/firewall/uspfilter/rule.go index ba1a2f3ce..40872f67d 100644 --- a/client/firewall/uspfilter/rule.go +++ b/client/firewall/uspfilter/rule.go @@ -20,6 +20,8 @@ type Rule struct { dPort uint16 drop bool comment string + + udpHook func([]byte) bool } // GetRuleID returns the rule id diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 5ff5634f0..adc2d552a 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -18,7 +18,7 @@ const layerTypeAll = 0 // IFaceMapper defines subset methods of interface required for manager type IFaceMapper interface { - SetFiltering(iface.PacketFilter) error + SetFilter(iface.PacketFilter) error } // Manager userspace firewall manager @@ -64,7 +64,7 @@ func Create(iface IFaceMapper) (*Manager, error) { }, } - if err := iface.SetFiltering(m); err != nil { + if err := iface.SetFilter(m); err != nil { return nil, err } return m, nil @@ -273,6 +273,12 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b return rule.drop } case layers.LayerTypeUDP: + // if rule has UDP hook (and if we are here we match this rule) + // we ignore rule.drop and call this hook + if rule.udpHook != nil { + return rule.udpHook(packetData) + } + if rule.sPort == 0 && rule.dPort == 0 { return rule.drop } @@ -296,3 +302,58 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b func (m *Manager) SetNetwork(network *net.IPNet) { m.wgNetwork = network } + +// AddUDPPacketHook calls hook when UDP packet from given direction matched +// +// Hook function returns flag which indicates should be the matched package dropped or not +func (m *Manager) AddUDPPacketHook( + in bool, ip net.IP, dPort uint16, hook func([]byte) bool, +) string { + r := Rule{ + id: uuid.New().String(), + ip: ip, + protoLayer: layers.LayerTypeUDP, + dPort: dPort, + ipLayer: layers.LayerTypeIPv6, + direction: fw.RuleDirectionOUT, + comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort), + udpHook: hook, + } + + if ip.To4() != nil { + r.ipLayer = layers.LayerTypeIPv4 + } + + m.mutex.Lock() + var toUpdate []Rule + if in { + r.direction = fw.RuleDirectionIN + m.incomingRules = append([]Rule{r}, m.incomingRules...) + toUpdate = m.incomingRules + } else { + m.outgoingRules = append([]Rule{r}, m.outgoingRules...) + toUpdate = m.outgoingRules + } + + for i := range toUpdate { + m.rulesIndex[toUpdate[i].id] = i + } + m.mutex.Unlock() + + return r.id +} + +// RemovePacketHook removes packet hook by given ID +func (m *Manager) RemovePacketHook(hookID string) error { + for _, r := range m.incomingRules { + if r.id == hookID { + return m.DeleteRule(&r) + } + } + for _, r := range m.outgoingRules { + if r.id == hookID { + return m.DeleteRule(&r) + } + } + return fmt.Errorf("hook with given id not found") +} diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index 4dc5cfddd..eed31c627 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -15,19 +15,19 @@ import ( ) type IFaceMock struct { - SetFilteringFunc func(iface.PacketFilter) error + SetFilterFunc func(iface.PacketFilter) error } -func (i *IFaceMock) SetFiltering(iface iface.PacketFilter) error { - if i.SetFilteringFunc == nil { +func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error { + if i.SetFilterFunc == nil { return fmt.Errorf("not implemented") } - return i.SetFilteringFunc(iface) + return i.SetFilterFunc(iface) } func TestManagerCreate(t *testing.T) { ifaceMock := &IFaceMock{ - SetFilteringFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(iface.PacketFilter) error { return nil }, } m, err := Create(ifaceMock) @@ -42,10 +42,10 @@ func TestManagerCreate(t *testing.T) { } func TestManagerAddFiltering(t *testing.T) { - isSetFilteringCalled := false + isSetFilterCalled := false ifaceMock := &IFaceMock{ - SetFilteringFunc: func(iface.PacketFilter) error { - isSetFilteringCalled = true + SetFilterFunc: func(iface.PacketFilter) error { + isSetFilterCalled = true return nil }, } @@ -74,15 +74,15 @@ func TestManagerAddFiltering(t *testing.T) { return } - if !isSetFilteringCalled { - t.Error("SetFiltering was not called") + if !isSetFilterCalled { + t.Error("SetFilter was not called") return } } func TestManagerDeleteRule(t *testing.T) { ifaceMock := &IFaceMock{ - SetFilteringFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(iface.PacketFilter) error { return nil }, } m, err := Create(ifaceMock) @@ -138,9 +138,97 @@ func TestManagerDeleteRule(t *testing.T) { } } +func TestAddUDPPacketHook(t *testing.T) { + tests := []struct { + name string + in bool + expDir fw.RuleDirection + ip net.IP + dPort uint16 + hook func([]byte) bool + expectedID string + }{ + { + name: "Test Outgoing UDP Packet Hook", + in: false, + expDir: fw.RuleDirectionOUT, + ip: net.IPv4(10, 168, 0, 1), + dPort: 8000, + hook: func([]byte) bool { return true }, + }, + { + name: "Test Incoming UDP Packet Hook", + in: true, + expDir: fw.RuleDirectionIN, + ip: net.IPv6loopback, + dPort: 9000, + hook: func([]byte) bool { return false }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + manager := &Manager{ + incomingRules: []Rule{}, + outgoingRules: []Rule{}, + rulesIndex: make(map[string]int), + } + + manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) + + var addedRule Rule + if tt.in { + if len(manager.incomingRules) != 1 { + t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules)) + return + } + addedRule = manager.incomingRules[0] + } else { + if len(manager.outgoingRules) != 1 { + t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules)) + return + } + addedRule = manager.outgoingRules[0] + } + + if !tt.ip.Equal(addedRule.ip) { + t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip) + return + } + if tt.dPort != addedRule.dPort { + t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort) + return + } + if layers.LayerTypeUDP != addedRule.protoLayer { + t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer) + return + } + if tt.expDir != addedRule.direction { + t.Errorf("expected direction %d, got %d", tt.expDir, addedRule.direction) + return + } + if addedRule.udpHook == nil { + t.Errorf("expected udpHook to be set") + return + } + + // Ensure rulesIndex is correctly updated + index, ok := manager.rulesIndex[addedRule.id] + if !ok { + t.Errorf("expected rule to be in rulesIndex") + return + } + if index != 0 { + t.Errorf("expected rule index to be 0, got %d", index) + return + } + }) + } +} + func TestManagerReset(t *testing.T) { ifaceMock := &IFaceMock{ - SetFilteringFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(iface.PacketFilter) error { return nil }, } m, err := Create(ifaceMock) @@ -175,7 +263,7 @@ func TestManagerReset(t *testing.T) { func TestNotMatchByIP(t *testing.T) { ifaceMock := &IFaceMock{ - SetFilteringFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(iface.PacketFilter) error { return nil }, } m, err := Create(ifaceMock) @@ -239,12 +327,56 @@ func TestNotMatchByIP(t *testing.T) { } } +// TestRemovePacketHook tests the functionality of the RemovePacketHook method +func TestRemovePacketHook(t *testing.T) { + // creating mock iface + iface := &IFaceMock{ + SetFilterFunc: func(iface.PacketFilter) error { return nil }, + } + + // creating manager instance + manager, err := Create(iface) + if err != nil { + t.Fatalf("Failed to create Manager: %s", err) + } + + // Add a UDP packet hook + hookFunc := func(data []byte) bool { return true } + hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc) + + // Assert the hook is added by finding it in the manager's outgoing rules + found := false + for _, rule := range manager.outgoingRules { + if rule.id == hookID { + found = true + break + } + } + + if !found { + t.Fatalf("The hook was not added properly.") + } + + // Now remove the packet hook + err = manager.RemovePacketHook(hookID) + if err != nil { + t.Fatalf("Failed to remove hook: %s", err) + } + + // Assert the hook is removed by checking it in the manager's outgoing rules + for _, rule := range manager.outgoingRules { + if rule.id == hookID { + t.Fatalf("The hook was not removed properly.") + } + } +} + func TestUSPFilterCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { // just check on the local interface ifaceMock := &IFaceMock{ - SetFilteringFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(iface.PacketFilter) error { return nil }, } manager, err := Create(ifaceMock) require.NoError(t, err) diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index dcc682d3d..83d0709f7 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -19,7 +19,7 @@ type IFaceMapper interface { Name() string Address() iface.WGAddress IsUserspaceBind() bool - SetFiltering(iface.PacketFilter) error + SetFilter(iface.PacketFilter) error } // Manager is a ACL rules manager diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index a0060f1df..366877337 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -35,7 +35,7 @@ func TestDefaultManager(t *testing.T) { iface := mocks.NewMockIFaceMapper(ctrl) iface.EXPECT().IsUserspaceBind().Return(true) // iface.EXPECT().Name().Return("lo") - iface.EXPECT().SetFiltering(gomock.Any()) + iface.EXPECT().SetFilter(gomock.Any()) // we receive one rule from the management so for testing purposes ignore it acl, err := Create(iface) @@ -311,7 +311,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { iface := mocks.NewMockIFaceMapper(ctrl) iface.EXPECT().IsUserspaceBind().Return(true) // iface.EXPECT().Name().Return("lo") - iface.EXPECT().SetFiltering(gomock.Any()) + iface.EXPECT().SetFilter(gomock.Any()) // we receive one rule from the management so for testing purposes ignore it acl, err := Create(iface) diff --git a/client/internal/acl/mocks/iface_mapper.go b/client/internal/acl/mocks/iface_mapper.go index cdfa7c46a..621b29513 100644 --- a/client/internal/acl/mocks/iface_mapper.go +++ b/client/internal/acl/mocks/iface_mapper.go @@ -76,16 +76,16 @@ func (mr *MockIFaceMapperMockRecorder) Name() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockIFaceMapper)(nil).Name)) } -// SetFiltering mocks base method. -func (m *MockIFaceMapper) SetFiltering(arg0 iface.PacketFilter) error { +// SetFilter mocks base method. +func (m *MockIFaceMapper) SetFilter(arg0 iface.PacketFilter) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetFiltering", arg0) + ret := m.ctrl.Call(m, "SetFilter", arg0) ret0, _ := ret[0].(error) return ret0 } -// SetFiltering indicates an expected call of SetFiltering. -func (mr *MockIFaceMapperMockRecorder) SetFiltering(arg0 interface{}) *gomock.Call { +// SetFilter indicates an expected call of SetFilter. +func (mr *MockIFaceMapperMockRecorder) SetFilter(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFiltering", reflect.TypeOf((*MockIFaceMapper)(nil).SetFiltering), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFilter", reflect.TypeOf((*MockIFaceMapper)(nil).SetFilter), arg0) } diff --git a/client/internal/dns/response_writer.go b/client/internal/dns/response_writer.go new file mode 100644 index 000000000..af02971b9 --- /dev/null +++ b/client/internal/dns/response_writer.go @@ -0,0 +1,103 @@ +package dns + +import ( + "fmt" + "net" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/miekg/dns" + "golang.zx2c4.com/wireguard/tun" +) + +type responseWriter struct { + local net.Addr + remote net.Addr + packet gopacket.Packet + device tun.Device +} + +// LocalAddr returns the net.Addr of the server +func (r *responseWriter) LocalAddr() net.Addr { + return r.local +} + +// RemoteAddr returns the net.Addr of the client that sent the current request. +func (r *responseWriter) RemoteAddr() net.Addr { + return r.remote +} + +// WriteMsg writes a reply back to the client. +func (r *responseWriter) WriteMsg(msg *dns.Msg) error { + buff, err := msg.Pack() + if err != nil { + return err + } + _, err = r.Write(buff) + return err +} + +// Write writes a raw buffer back to the client. +func (r *responseWriter) Write(data []byte) (int, error) { + var ip gopacket.SerializableLayer + + // Get the UDP layer + udpLayer := r.packet.Layer(layers.LayerTypeUDP) + udp := udpLayer.(*layers.UDP) + + // Swap the source and destination addresses for the response + udp.SrcPort, udp.DstPort = udp.DstPort, udp.SrcPort + + // Check if it's an IPv4 packet + if ipv4Layer := r.packet.Layer(layers.LayerTypeIPv4); ipv4Layer != nil { + ipv4 := ipv4Layer.(*layers.IPv4) + ipv4.SrcIP, ipv4.DstIP = ipv4.DstIP, ipv4.SrcIP + ip = ipv4 + } else if ipv6Layer := r.packet.Layer(layers.LayerTypeIPv6); ipv6Layer != nil { + ipv6 := ipv6Layer.(*layers.IPv6) + ipv6.SrcIP, ipv6.DstIP = ipv6.DstIP, ipv6.SrcIP + ip = ipv6 + } + + if err := udp.SetNetworkLayerForChecksum(ip.(gopacket.NetworkLayer)); err != nil { + return 0, fmt.Errorf("failed to set network layer for checksum: %v", err) + } + + // Serialize the packet + buffer := gopacket.NewSerializeBuffer() + options := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + + payload := gopacket.Payload(data) + err := gopacket.SerializeLayers(buffer, options, ip, udp, payload) + if err != nil { + return 0, fmt.Errorf("failed to serialize packet: %v", err) + } + + send := buffer.Bytes() + sendBuffer := make([]byte, 40, len(send)+40) + sendBuffer = append(sendBuffer, send...) + + return r.device.Write([][]byte{sendBuffer}, 40) +} + +// Close closes the connection. +func (r *responseWriter) Close() error { + return nil +} + +// TsigStatus returns the status of the Tsig. +func (r *responseWriter) TsigStatus() error { + return nil +} + +// TsigTimersOnly sets the tsig timers only boolean. +func (r *responseWriter) TsigTimersOnly(bool) { +} + +// Hijack lets the caller take over the connection. +// After a call to Hijack(), the DNS package will not do anything with the connection. +func (r *responseWriter) Hijack() { +} diff --git a/client/internal/dns/response_writer_test.go b/client/internal/dns/response_writer_test.go new file mode 100644 index 000000000..5a0047700 --- /dev/null +++ b/client/internal/dns/response_writer_test.go @@ -0,0 +1,93 @@ +package dns + +import ( + "net" + "testing" + + "github.com/golang/mock/gomock" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/miekg/dns" + + "github.com/netbirdio/netbird/iface/mocks" +) + +func TestResponseWriterLocalAddr(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + device := mocks.NewMockDevice(ctrl) + device.EXPECT().Write(gomock.Any(), gomock.Any()) + + request := &dns.Msg{ + Question: []dns.Question{{ + Name: "google.com.", + Qtype: dns.TypeA, + Qclass: dns.TypeA, + }}, + } + + replyMessage := &dns.Msg{} + replyMessage.SetReply(request) + replyMessage.RecursionAvailable = true + replyMessage.Rcode = dns.RcodeSuccess + replyMessage.Answer = []dns.RR{ + &dns.A{ + A: net.IPv4(8, 8, 8, 8), + }, + } + + ipv4 := &layers.IPv4{ + Protocol: layers.IPProtocolUDP, + SrcIP: net.IPv4(127, 0, 0, 1), + DstIP: net.IPv4(127, 0, 0, 2), + } + udp := &layers.UDP{ + DstPort: 53, + SrcPort: 45223, + } + if err := udp.SetNetworkLayerForChecksum(ipv4); err != nil { + t.Error("failed to set network layer for checksum") + return + } + + // Serialize the packet + buffer := gopacket.NewSerializeBuffer() + options := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + + requestData, err := request.Pack() + if err != nil { + t.Errorf("got an error while packing the request message, error: %v", err) + return + } + payload := gopacket.Payload(requestData) + + if err := gopacket.SerializeLayers(buffer, options, ipv4, udp, payload); err != nil { + t.Errorf("failed to serialize packet: %v", err) + return + } + + rw := &responseWriter{ + local: &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 55223, + }, + remote: &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 53, + }, + packet: gopacket.NewPacket( + buffer.Bytes(), + layers.LayerTypeIPv4, + gopacket.Default, + ), + device: device, + } + if err := rw.WriteMsg(replyMessage); err != nil { + t.Errorf("got an error while writing the local resolver response, error: %v", err) + return + } +} diff --git a/client/internal/dns/server_nonandroid.go b/client/internal/dns/server_nonandroid.go index ca4b25708..ec970bccf 100644 --- a/client/internal/dns/server_nonandroid.go +++ b/client/internal/dns/server_nonandroid.go @@ -5,12 +5,15 @@ package dns import ( "context" "fmt" + "math/big" "net" "net/netip" "runtime" "sync" "time" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" "github.com/miekg/dns" "github.com/mitchellh/hashstructure/v2" log "github.com/sirupsen/logrus" @@ -33,6 +36,7 @@ type DefaultServer struct { ctx context.Context ctxCancel context.CancelFunc mux sync.Mutex + fakeResolverWG sync.WaitGroup server *dns.Server dnsMux *dns.ServeMux dnsMuxMap registeredHandlerMap @@ -105,6 +109,25 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd // Start runs the listener in a go routine func (s *DefaultServer) Start() { + if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() { + s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1) + s.runtimePort = 53 + + s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort) + s.fakeResolverWG.Add(1) + go func() { + s.setListenerStatus(true) + defer s.setListenerStatus(false) + + hookID := s.filterDNSTraffic() + s.fakeResolverWG.Wait() + if err := s.wgInterface.GetFilter().RemovePacketHook(hookID); err != nil { + log.Errorf("unable to remove DNS packet hook: %s", err) + } + }() + return + } + if s.customAddress != nil { s.runtimeIP = s.customAddress.Addr().String() s.runtimePort = int(s.customAddress.Port()) @@ -172,6 +195,10 @@ func (s *DefaultServer) Stop() { log.Error(err) } + if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning { + s.fakeResolverWG.Done() + } + err = s.stopListener() if err != nil { log.Error(err) @@ -235,12 +262,15 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro } func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { - // is the service should be disabled, we stop the listener + // is the service should be disabled, we stop the listener or fake resolver // and proceed with a regular update to clean up the handlers and records if !update.ServiceEnable { - err := s.stopListener() - if err != nil { - log.Error(err) + if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning { + s.fakeResolverWG.Done() + } else { + if err := s.stopListener(); err != nil { + log.Error(err) + } } } else if !s.listenerIsRunning { s.Start() @@ -477,3 +507,59 @@ func (s *DefaultServer) upstreamCallbacks( } return } + +func (s *DefaultServer) filterDNSTraffic() string { + filter := s.wgInterface.GetFilter() + if filter == nil { + log.Error("can't set DNS filter, filter not initialized") + return "" + } + + firstLayerDecoder := layers.LayerTypeIPv4 + if s.wgInterface.Address().Network.IP.To4() == nil { + firstLayerDecoder = layers.LayerTypeIPv6 + } + + hook := func(packetData []byte) bool { + // Decode the packet + packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default) + + // Get the UDP layer + udpLayer := packet.Layer(layers.LayerTypeUDP) + udp := udpLayer.(*layers.UDP) + + msg := new(dns.Msg) + if err := msg.Unpack(udp.Payload); err != nil { + log.Tracef("parse DNS request: %v", err) + return true + } + + writer := responseWriter{ + packet: packet, + device: s.wgInterface.GetDevice().Device, + } + go s.dnsMux.ServeDNS(&writer, msg) + return true + } + + return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook) +} + +func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string { + // Calculate the last IP in the CIDR range + var endIP net.IP + for i := 0; i < len(network.IP); i++ { + endIP = append(endIP, network.IP[i]|^network.Mask[i]) + } + + // convert to big.Int + endInt := big.NewInt(0) + endInt.SetBytes(endIP) + + // subtract fromEnd from the last ip + fromEndBig := big.NewInt(int64(fromEnd)) + resultInt := big.NewInt(0) + resultInt.Sub(endInt, fromEndBig) + + return net.IP(resultInt.Bytes()).String() +} diff --git a/client/internal/dns/server_nonandroid_test.go b/client/internal/dns/server_nonandroid_test.go new file mode 100644 index 000000000..bea4f4ce8 --- /dev/null +++ b/client/internal/dns/server_nonandroid_test.go @@ -0,0 +1,31 @@ +package dns + +import ( + "net" + "testing" +) + +func TestGetLastIPFromNetwork(t *testing.T) { + tests := []struct { + addr string + ip string + }{ + {"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"}, + {"192.168.0.0/30", "192.168.0.2"}, + {"192.168.0.0/16", "192.168.255.254"}, + {"192.168.0.0/24", "192.168.0.254"}, + } + + for _, tt := range tests { + _, ipnet, err := net.ParseCIDR(tt.addr) + if err != nil { + t.Errorf("Error parsing CIDR: %v", err) + return + } + + lastIP := getLastIPFromNetwork(ipnet, 1) + if lastIP != tt.ip { + t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP) + } + } +} diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 260ac7d67..201392105 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -9,10 +9,9 @@ import ( "testing" "time" - "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/miekg/dns" + "github.com/netbirdio/netbird/client/internal/stdnet" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/iface" ) @@ -238,6 +237,7 @@ func TestUpdateDNSServer(t *testing.T) { dnsServer.updateSerial = testCase.initSerial // pretend we are running dnsServer.listenerIsRunning = true + dnsServer.fakeResolverWG.Add(1) err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate) if err != nil { diff --git a/iface/device_wrapper.go b/iface/device_wrapper.go index d6e39b81b..2fa219395 100644 --- a/iface/device_wrapper.go +++ b/iface/device_wrapper.go @@ -15,6 +15,15 @@ type PacketFilter interface { // DropIncoming filter incoming packets from external sources to host DropIncoming(packetData []byte) bool + // AddUDPPacketHook calls hook when UDP packet from given direction matched + // + // Hook function returns flag which indicates should be the matched package dropped or not. + // Hook function receives raw network packet data as argument. + AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string + + // RemovePacketHook removes hook by ID + RemovePacketHook(hookID string) error + // SetNetwork of the wireguard interface to which filtering applied SetNetwork(*net.IPNet) } @@ -82,8 +91,8 @@ func (d *DeviceWrapper) Write(bufs [][]byte, offset int) (int, error) { return n, err } -// SetFiltering sets packet filter to device -func (d *DeviceWrapper) SetFiltering(filter PacketFilter) { +// SetFilter sets packet filter to device +func (d *DeviceWrapper) SetFilter(filter PacketFilter) { d.mutex.Lock() d.filter = filter d.mutex.Unlock() diff --git a/iface/device_wrapper_test.go b/iface/device_wrapper_test.go index 9d1045100..9f5386587 100644 --- a/iface/device_wrapper_test.go +++ b/iface/device_wrapper_test.go @@ -14,13 +14,6 @@ func TestDeviceWrapperRead(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - tun := mocks.NewMockDevice(ctrl) - filter := mocks.NewMockPacketFilter(ctrl) - - mockBufs := [][]byte{{}} - mockSizes := []int{0} - mockOffset := 0 - t.Run("read ICMP", func(t *testing.T) { ipLayer := &layers.IPv4{ Version: 4, @@ -46,6 +39,11 @@ func TestDeviceWrapperRead(t *testing.T) { return } + mockBufs := [][]byte{{}} + mockSizes := []int{0} + mockOffset := 0 + + tun := mocks.NewMockDevice(ctrl) tun.EXPECT().Read(mockBufs, mockSizes, mockOffset). DoAndReturn(func(bufs [][]byte, sizes []int, offset int) (int, error) { bufs[0] = buffer.Bytes() @@ -95,7 +93,10 @@ func TestDeviceWrapperRead(t *testing.T) { return } + mockBufs := [][]byte{buffer.Bytes()} + mockBufs[0] = buffer.Bytes() + tun := mocks.NewMockDevice(ctrl) tun.EXPECT().Write(mockBufs, 0).Return(1, nil) wrapped := newDeviceWrapper(tun) @@ -138,10 +139,13 @@ func TestDeviceWrapperRead(t *testing.T) { return } - mockBufs = [][]byte{} + mockBufs := [][]byte{} + tun := mocks.NewMockDevice(ctrl) tun.EXPECT().Write(mockBufs, 0).Return(0, nil) - filter.EXPECT().DropOutput(gomock.Any()).Return(true) + + filter := mocks.NewMockPacketFilter(ctrl) + filter.EXPECT().DropIncoming(gomock.Any()).Return(true) wrapped := newDeviceWrapper(tun) wrapped.filter = filter @@ -188,13 +192,15 @@ func TestDeviceWrapperRead(t *testing.T) { mockSizes := []int{0} mockOffset := 0 + tun := mocks.NewMockDevice(ctrl) tun.EXPECT().Read(mockBufs, mockSizes, mockOffset). DoAndReturn(func(bufs [][]byte, sizes []int, offset int) (int, error) { bufs[0] = buffer.Bytes() sizes[0] = len(bufs[0]) return 1, nil }) - filter.EXPECT().DropInput(gomock.Any()).Return(true) + filter := mocks.NewMockPacketFilter(ctrl) + filter.EXPECT().DropOutgoing(gomock.Any()).Return(true) wrapped := newDeviceWrapper(tun) wrapped.filter = filter diff --git a/iface/iface.go b/iface/iface.go index 100e16982..4fcc064d1 100644 --- a/iface/iface.go +++ b/iface/iface.go @@ -23,6 +23,7 @@ type WGIface struct { configurer wGConfigurer mu sync.Mutex userspaceBind bool + filter PacketFilter } // IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind @@ -120,8 +121,8 @@ func (w *WGIface) Close() error { return w.tun.Close() } -// SetFiltering sets packet filters for the userspace impelemntation -func (w *WGIface) SetFiltering(filter PacketFilter) error { +// SetFilter sets packet filters for the userspace impelemntation +func (w *WGIface) SetFilter(filter PacketFilter) error { w.mu.Lock() defer w.mu.Unlock() @@ -129,7 +130,25 @@ func (w *WGIface) SetFiltering(filter PacketFilter) error { return fmt.Errorf("userspace packet filtering not handled on this device") } - filter.SetNetwork(w.tun.address.Network) - w.tun.wrapper.SetFiltering(filter) + w.filter = filter + w.filter.SetNetwork(w.tun.address.Network) + + w.tun.wrapper.SetFilter(filter) return nil } + +// GetFilter returns packet filter used by interface if it uses userspace device implementation +func (w *WGIface) GetFilter() PacketFilter { + w.mu.Lock() + defer w.mu.Unlock() + + return w.filter +} + +// GetDevice to interact with raw device (with filtering) +func (w *WGIface) GetDevice() *DeviceWrapper { + w.mu.Lock() + defer w.mu.Unlock() + + return w.tun.wrapper +} diff --git a/iface/mocks/README.md b/iface/mocks/README.md new file mode 100644 index 000000000..7d3ea1e2c --- /dev/null +++ b/iface/mocks/README.md @@ -0,0 +1,7 @@ +## Mocks + +To generate (or refresh) mocks from iface package interfaces please install [mockgen](https://github.com/golang/mock). +Run this command to update PacketFilter mock: +```bash +mockgen -destination iface/mocks/filter.go -package mocks github.com/netbirdio/netbird/iface PacketFilter +``` diff --git a/iface/mocks/filter.go b/iface/mocks/filter.go index 9537520a2..2d80d69f1 100644 --- a/iface/mocks/filter.go +++ b/iface/mocks/filter.go @@ -34,21 +34,21 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder { return m.recorder } -// DropInput mocks base method. -func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool { +// AddUDPPacketHook mocks base method. +func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func([]byte) bool) string { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DropOutgoing", arg0) - ret0, _ := ret[0].(bool) + ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(string) return ret0 } -// DropInput indicates an expected call of DropInput. -func (mr *MockPacketFilterMockRecorder) DropInput(arg0 interface{}) *gomock.Call { +// AddUDPPacketHook indicates an expected call of AddUDPPacketHook. +func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3) } -// DropOutput mocks base method. +// DropIncoming mocks base method. func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "DropIncoming", arg0) @@ -56,12 +56,40 @@ func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool { return ret0 } -// DropOutput indicates an expected call of DropOutput. -func (mr *MockPacketFilterMockRecorder) DropOutput(arg0 interface{}) *gomock.Call { +// DropIncoming indicates an expected call of DropIncoming. +func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0) } +// DropOutgoing mocks base method. +func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DropOutgoing", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// DropOutgoing indicates an expected call of DropOutgoing. +func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0) +} + +// RemovePacketHook mocks base method. +func (m *MockPacketFilter) RemovePacketHook(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemovePacketHook", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// RemovePacketHook indicates an expected call of RemovePacketHook. +func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0) +} + // SetNetwork mocks base method. func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) { m.ctrl.T.Helper() diff --git a/iface/mocks/iface/mocks/filter.go b/iface/mocks/iface/mocks/filter.go new file mode 100644 index 000000000..059a2b9a0 --- /dev/null +++ b/iface/mocks/iface/mocks/filter.go @@ -0,0 +1,87 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/netbirdio/netbird/iface (interfaces: PacketFilter) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + net "net" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockPacketFilter is a mock of PacketFilter interface. +type MockPacketFilter struct { + ctrl *gomock.Controller + recorder *MockPacketFilterMockRecorder +} + +// MockPacketFilterMockRecorder is the mock recorder for MockPacketFilter. +type MockPacketFilterMockRecorder struct { + mock *MockPacketFilter +} + +// NewMockPacketFilter creates a new mock instance. +func NewMockPacketFilter(ctrl *gomock.Controller) *MockPacketFilter { + mock := &MockPacketFilter{ctrl: ctrl} + mock.recorder = &MockPacketFilterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder { + return m.recorder +} + +// AddUDPPacketHook mocks base method. +func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func(*net.UDPAddr, []byte) bool) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3) +} + +// AddUDPPacketHook indicates an expected call of AddUDPPacketHook. +func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3) +} + +// DropIncoming mocks base method. +func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DropIncoming", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// DropIncoming indicates an expected call of DropIncoming. +func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0) +} + +// DropOutgoing mocks base method. +func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DropOutgoing", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// DropOutgoing indicates an expected call of DropOutgoing. +func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0) +} + +// SetNetwork mocks base method. +func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetNetwork", arg0) +} + +// SetNetwork indicates an expected call of SetNetwork. +func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0) +}