diff --git a/client/firewall/uspfilter/rule.go b/client/firewall/uspfilter/rule.go index ba1a2f3ce..40872f67d 100644 --- a/client/firewall/uspfilter/rule.go +++ b/client/firewall/uspfilter/rule.go @@ -20,6 +20,8 @@ type Rule struct { dPort uint16 drop bool comment string + + udpHook func([]byte) bool } // GetRuleID returns the rule id diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 5ff5634f0..adc2d552a 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -18,7 +18,7 @@ const layerTypeAll = 0 // IFaceMapper defines subset methods of interface required for manager type IFaceMapper interface { - SetFiltering(iface.PacketFilter) error + SetFilter(iface.PacketFilter) error } // Manager userspace firewall manager @@ -64,7 +64,7 @@ func Create(iface IFaceMapper) (*Manager, error) { }, } - if err := iface.SetFiltering(m); err != nil { + if err := iface.SetFilter(m); err != nil { return nil, err } return m, nil @@ -273,6 +273,12 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b return rule.drop } case layers.LayerTypeUDP: + // if rule has UDP hook (and if we are here we match this rule) + // we ignore rule.drop and call this hook + if rule.udpHook != nil { + return rule.udpHook(packetData) + } + if rule.sPort == 0 && rule.dPort == 0 { return rule.drop } @@ -296,3 +302,58 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b func (m *Manager) SetNetwork(network *net.IPNet) { m.wgNetwork = network } + +// AddUDPPacketHook calls hook when UDP packet from given direction matched +// +// Hook function returns flag which indicates should be the matched package dropped or not +func (m *Manager) AddUDPPacketHook( + in bool, ip net.IP, dPort uint16, hook func([]byte) bool, +) string { + r := Rule{ + id: uuid.New().String(), + ip: ip, + protoLayer: layers.LayerTypeUDP, + dPort: dPort, + ipLayer: layers.LayerTypeIPv6, + direction: fw.RuleDirectionOUT, + comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort), + udpHook: hook, + } + + if ip.To4() != nil { + r.ipLayer = layers.LayerTypeIPv4 + } + + m.mutex.Lock() + var toUpdate []Rule + if in { + r.direction = fw.RuleDirectionIN + m.incomingRules = append([]Rule{r}, m.incomingRules...) + toUpdate = m.incomingRules + } else { + m.outgoingRules = append([]Rule{r}, m.outgoingRules...) + toUpdate = m.outgoingRules + } + + for i := range toUpdate { + m.rulesIndex[toUpdate[i].id] = i + } + m.mutex.Unlock() + + return r.id +} + +// RemovePacketHook removes packet hook by given ID +func (m *Manager) RemovePacketHook(hookID string) error { + for _, r := range m.incomingRules { + if r.id == hookID { + return m.DeleteRule(&r) + } + } + for _, r := range m.outgoingRules { + if r.id == hookID { + return m.DeleteRule(&r) + } + } + return fmt.Errorf("hook with given id not found") +} diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index 4dc5cfddd..eed31c627 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -15,19 +15,19 @@ import ( ) type IFaceMock struct { - SetFilteringFunc func(iface.PacketFilter) error + SetFilterFunc func(iface.PacketFilter) error } -func (i *IFaceMock) SetFiltering(iface iface.PacketFilter) error { - if i.SetFilteringFunc == nil { +func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error { + if i.SetFilterFunc == nil { return fmt.Errorf("not implemented") } - return i.SetFilteringFunc(iface) + return i.SetFilterFunc(iface) } func TestManagerCreate(t *testing.T) { ifaceMock := &IFaceMock{ - SetFilteringFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(iface.PacketFilter) error { return nil }, } m, err := Create(ifaceMock) @@ -42,10 +42,10 @@ func TestManagerCreate(t *testing.T) { } func TestManagerAddFiltering(t *testing.T) { - isSetFilteringCalled := false + isSetFilterCalled := false ifaceMock := &IFaceMock{ - SetFilteringFunc: func(iface.PacketFilter) error { - isSetFilteringCalled = true + SetFilterFunc: func(iface.PacketFilter) error { + isSetFilterCalled = true return nil }, } @@ -74,15 +74,15 @@ func TestManagerAddFiltering(t *testing.T) { return } - if !isSetFilteringCalled { - t.Error("SetFiltering was not called") + if !isSetFilterCalled { + t.Error("SetFilter was not called") return } } func TestManagerDeleteRule(t *testing.T) { ifaceMock := &IFaceMock{ - SetFilteringFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(iface.PacketFilter) error { return nil }, } m, err := Create(ifaceMock) @@ -138,9 +138,97 @@ func TestManagerDeleteRule(t *testing.T) { } } +func TestAddUDPPacketHook(t *testing.T) { + tests := []struct { + name string + in bool + expDir fw.RuleDirection + ip net.IP + dPort uint16 + hook func([]byte) bool + expectedID string + }{ + { + name: "Test Outgoing UDP Packet Hook", + in: false, + expDir: fw.RuleDirectionOUT, + ip: net.IPv4(10, 168, 0, 1), + dPort: 8000, + hook: func([]byte) bool { return true }, + }, + { + name: "Test Incoming UDP Packet Hook", + in: true, + expDir: fw.RuleDirectionIN, + ip: net.IPv6loopback, + dPort: 9000, + hook: func([]byte) bool { return false }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + manager := &Manager{ + incomingRules: []Rule{}, + outgoingRules: []Rule{}, + rulesIndex: make(map[string]int), + } + + manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) + + var addedRule Rule + if tt.in { + if len(manager.incomingRules) != 1 { + t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules)) + return + } + addedRule = manager.incomingRules[0] + } else { + if len(manager.outgoingRules) != 1 { + t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules)) + return + } + addedRule = manager.outgoingRules[0] + } + + if !tt.ip.Equal(addedRule.ip) { + t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip) + return + } + if tt.dPort != addedRule.dPort { + t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort) + return + } + if layers.LayerTypeUDP != addedRule.protoLayer { + t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer) + return + } + if tt.expDir != addedRule.direction { + t.Errorf("expected direction %d, got %d", tt.expDir, addedRule.direction) + return + } + if addedRule.udpHook == nil { + t.Errorf("expected udpHook to be set") + return + } + + // Ensure rulesIndex is correctly updated + index, ok := manager.rulesIndex[addedRule.id] + if !ok { + t.Errorf("expected rule to be in rulesIndex") + return + } + if index != 0 { + t.Errorf("expected rule index to be 0, got %d", index) + return + } + }) + } +} + func TestManagerReset(t *testing.T) { ifaceMock := &IFaceMock{ - SetFilteringFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(iface.PacketFilter) error { return nil }, } m, err := Create(ifaceMock) @@ -175,7 +263,7 @@ func TestManagerReset(t *testing.T) { func TestNotMatchByIP(t *testing.T) { ifaceMock := &IFaceMock{ - SetFilteringFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(iface.PacketFilter) error { return nil }, } m, err := Create(ifaceMock) @@ -239,12 +327,56 @@ func TestNotMatchByIP(t *testing.T) { } } +// TestRemovePacketHook tests the functionality of the RemovePacketHook method +func TestRemovePacketHook(t *testing.T) { + // creating mock iface + iface := &IFaceMock{ + SetFilterFunc: func(iface.PacketFilter) error { return nil }, + } + + // creating manager instance + manager, err := Create(iface) + if err != nil { + t.Fatalf("Failed to create Manager: %s", err) + } + + // Add a UDP packet hook + hookFunc := func(data []byte) bool { return true } + hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc) + + // Assert the hook is added by finding it in the manager's outgoing rules + found := false + for _, rule := range manager.outgoingRules { + if rule.id == hookID { + found = true + break + } + } + + if !found { + t.Fatalf("The hook was not added properly.") + } + + // Now remove the packet hook + err = manager.RemovePacketHook(hookID) + if err != nil { + t.Fatalf("Failed to remove hook: %s", err) + } + + // Assert the hook is removed by checking it in the manager's outgoing rules + for _, rule := range manager.outgoingRules { + if rule.id == hookID { + t.Fatalf("The hook was not removed properly.") + } + } +} + func TestUSPFilterCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { // just check on the local interface ifaceMock := &IFaceMock{ - SetFilteringFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(iface.PacketFilter) error { return nil }, } manager, err := Create(ifaceMock) require.NoError(t, err) diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index dcc682d3d..83d0709f7 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -19,7 +19,7 @@ type IFaceMapper interface { Name() string Address() iface.WGAddress IsUserspaceBind() bool - SetFiltering(iface.PacketFilter) error + SetFilter(iface.PacketFilter) error } // Manager is a ACL rules manager diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index a0060f1df..366877337 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -35,7 +35,7 @@ func TestDefaultManager(t *testing.T) { iface := mocks.NewMockIFaceMapper(ctrl) iface.EXPECT().IsUserspaceBind().Return(true) // iface.EXPECT().Name().Return("lo") - iface.EXPECT().SetFiltering(gomock.Any()) + iface.EXPECT().SetFilter(gomock.Any()) // we receive one rule from the management so for testing purposes ignore it acl, err := Create(iface) @@ -311,7 +311,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { iface := mocks.NewMockIFaceMapper(ctrl) iface.EXPECT().IsUserspaceBind().Return(true) // iface.EXPECT().Name().Return("lo") - iface.EXPECT().SetFiltering(gomock.Any()) + iface.EXPECT().SetFilter(gomock.Any()) // we receive one rule from the management so for testing purposes ignore it acl, err := Create(iface) diff --git a/client/internal/acl/mocks/iface_mapper.go b/client/internal/acl/mocks/iface_mapper.go index cdfa7c46a..621b29513 100644 --- a/client/internal/acl/mocks/iface_mapper.go +++ b/client/internal/acl/mocks/iface_mapper.go @@ -76,16 +76,16 @@ func (mr *MockIFaceMapperMockRecorder) Name() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockIFaceMapper)(nil).Name)) } -// SetFiltering mocks base method. -func (m *MockIFaceMapper) SetFiltering(arg0 iface.PacketFilter) error { +// SetFilter mocks base method. +func (m *MockIFaceMapper) SetFilter(arg0 iface.PacketFilter) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetFiltering", arg0) + ret := m.ctrl.Call(m, "SetFilter", arg0) ret0, _ := ret[0].(error) return ret0 } -// SetFiltering indicates an expected call of SetFiltering. -func (mr *MockIFaceMapperMockRecorder) SetFiltering(arg0 interface{}) *gomock.Call { +// SetFilter indicates an expected call of SetFilter. +func (mr *MockIFaceMapperMockRecorder) SetFilter(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFiltering", reflect.TypeOf((*MockIFaceMapper)(nil).SetFiltering), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFilter", reflect.TypeOf((*MockIFaceMapper)(nil).SetFilter), arg0) } diff --git a/client/internal/dns/response_writer.go b/client/internal/dns/response_writer.go new file mode 100644 index 000000000..af02971b9 --- /dev/null +++ b/client/internal/dns/response_writer.go @@ -0,0 +1,103 @@ +package dns + +import ( + "fmt" + "net" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/miekg/dns" + "golang.zx2c4.com/wireguard/tun" +) + +type responseWriter struct { + local net.Addr + remote net.Addr + packet gopacket.Packet + device tun.Device +} + +// LocalAddr returns the net.Addr of the server +func (r *responseWriter) LocalAddr() net.Addr { + return r.local +} + +// RemoteAddr returns the net.Addr of the client that sent the current request. +func (r *responseWriter) RemoteAddr() net.Addr { + return r.remote +} + +// WriteMsg writes a reply back to the client. +func (r *responseWriter) WriteMsg(msg *dns.Msg) error { + buff, err := msg.Pack() + if err != nil { + return err + } + _, err = r.Write(buff) + return err +} + +// Write writes a raw buffer back to the client. +func (r *responseWriter) Write(data []byte) (int, error) { + var ip gopacket.SerializableLayer + + // Get the UDP layer + udpLayer := r.packet.Layer(layers.LayerTypeUDP) + udp := udpLayer.(*layers.UDP) + + // Swap the source and destination addresses for the response + udp.SrcPort, udp.DstPort = udp.DstPort, udp.SrcPort + + // Check if it's an IPv4 packet + if ipv4Layer := r.packet.Layer(layers.LayerTypeIPv4); ipv4Layer != nil { + ipv4 := ipv4Layer.(*layers.IPv4) + ipv4.SrcIP, ipv4.DstIP = ipv4.DstIP, ipv4.SrcIP + ip = ipv4 + } else if ipv6Layer := r.packet.Layer(layers.LayerTypeIPv6); ipv6Layer != nil { + ipv6 := ipv6Layer.(*layers.IPv6) + ipv6.SrcIP, ipv6.DstIP = ipv6.DstIP, ipv6.SrcIP + ip = ipv6 + } + + if err := udp.SetNetworkLayerForChecksum(ip.(gopacket.NetworkLayer)); err != nil { + return 0, fmt.Errorf("failed to set network layer for checksum: %v", err) + } + + // Serialize the packet + buffer := gopacket.NewSerializeBuffer() + options := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + + payload := gopacket.Payload(data) + err := gopacket.SerializeLayers(buffer, options, ip, udp, payload) + if err != nil { + return 0, fmt.Errorf("failed to serialize packet: %v", err) + } + + send := buffer.Bytes() + sendBuffer := make([]byte, 40, len(send)+40) + sendBuffer = append(sendBuffer, send...) + + return r.device.Write([][]byte{sendBuffer}, 40) +} + +// Close closes the connection. +func (r *responseWriter) Close() error { + return nil +} + +// TsigStatus returns the status of the Tsig. +func (r *responseWriter) TsigStatus() error { + return nil +} + +// TsigTimersOnly sets the tsig timers only boolean. +func (r *responseWriter) TsigTimersOnly(bool) { +} + +// Hijack lets the caller take over the connection. +// After a call to Hijack(), the DNS package will not do anything with the connection. +func (r *responseWriter) Hijack() { +} diff --git a/client/internal/dns/response_writer_test.go b/client/internal/dns/response_writer_test.go new file mode 100644 index 000000000..5a0047700 --- /dev/null +++ b/client/internal/dns/response_writer_test.go @@ -0,0 +1,93 @@ +package dns + +import ( + "net" + "testing" + + "github.com/golang/mock/gomock" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/miekg/dns" + + "github.com/netbirdio/netbird/iface/mocks" +) + +func TestResponseWriterLocalAddr(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + device := mocks.NewMockDevice(ctrl) + device.EXPECT().Write(gomock.Any(), gomock.Any()) + + request := &dns.Msg{ + Question: []dns.Question{{ + Name: "google.com.", + Qtype: dns.TypeA, + Qclass: dns.TypeA, + }}, + } + + replyMessage := &dns.Msg{} + replyMessage.SetReply(request) + replyMessage.RecursionAvailable = true + replyMessage.Rcode = dns.RcodeSuccess + replyMessage.Answer = []dns.RR{ + &dns.A{ + A: net.IPv4(8, 8, 8, 8), + }, + } + + ipv4 := &layers.IPv4{ + Protocol: layers.IPProtocolUDP, + SrcIP: net.IPv4(127, 0, 0, 1), + DstIP: net.IPv4(127, 0, 0, 2), + } + udp := &layers.UDP{ + DstPort: 53, + SrcPort: 45223, + } + if err := udp.SetNetworkLayerForChecksum(ipv4); err != nil { + t.Error("failed to set network layer for checksum") + return + } + + // Serialize the packet + buffer := gopacket.NewSerializeBuffer() + options := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + + requestData, err := request.Pack() + if err != nil { + t.Errorf("got an error while packing the request message, error: %v", err) + return + } + payload := gopacket.Payload(requestData) + + if err := gopacket.SerializeLayers(buffer, options, ipv4, udp, payload); err != nil { + t.Errorf("failed to serialize packet: %v", err) + return + } + + rw := &responseWriter{ + local: &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 55223, + }, + remote: &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 53, + }, + packet: gopacket.NewPacket( + buffer.Bytes(), + layers.LayerTypeIPv4, + gopacket.Default, + ), + device: device, + } + if err := rw.WriteMsg(replyMessage); err != nil { + t.Errorf("got an error while writing the local resolver response, error: %v", err) + return + } +} diff --git a/client/internal/dns/server_nonandroid.go b/client/internal/dns/server_nonandroid.go index ca4b25708..ec970bccf 100644 --- a/client/internal/dns/server_nonandroid.go +++ b/client/internal/dns/server_nonandroid.go @@ -5,12 +5,15 @@ package dns import ( "context" "fmt" + "math/big" "net" "net/netip" "runtime" "sync" "time" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" "github.com/miekg/dns" "github.com/mitchellh/hashstructure/v2" log "github.com/sirupsen/logrus" @@ -33,6 +36,7 @@ type DefaultServer struct { ctx context.Context ctxCancel context.CancelFunc mux sync.Mutex + fakeResolverWG sync.WaitGroup server *dns.Server dnsMux *dns.ServeMux dnsMuxMap registeredHandlerMap @@ -105,6 +109,25 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd // Start runs the listener in a go routine func (s *DefaultServer) Start() { + if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() { + s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1) + s.runtimePort = 53 + + s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort) + s.fakeResolverWG.Add(1) + go func() { + s.setListenerStatus(true) + defer s.setListenerStatus(false) + + hookID := s.filterDNSTraffic() + s.fakeResolverWG.Wait() + if err := s.wgInterface.GetFilter().RemovePacketHook(hookID); err != nil { + log.Errorf("unable to remove DNS packet hook: %s", err) + } + }() + return + } + if s.customAddress != nil { s.runtimeIP = s.customAddress.Addr().String() s.runtimePort = int(s.customAddress.Port()) @@ -172,6 +195,10 @@ func (s *DefaultServer) Stop() { log.Error(err) } + if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning { + s.fakeResolverWG.Done() + } + err = s.stopListener() if err != nil { log.Error(err) @@ -235,12 +262,15 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro } func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { - // is the service should be disabled, we stop the listener + // is the service should be disabled, we stop the listener or fake resolver // and proceed with a regular update to clean up the handlers and records if !update.ServiceEnable { - err := s.stopListener() - if err != nil { - log.Error(err) + if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning { + s.fakeResolverWG.Done() + } else { + if err := s.stopListener(); err != nil { + log.Error(err) + } } } else if !s.listenerIsRunning { s.Start() @@ -477,3 +507,59 @@ func (s *DefaultServer) upstreamCallbacks( } return } + +func (s *DefaultServer) filterDNSTraffic() string { + filter := s.wgInterface.GetFilter() + if filter == nil { + log.Error("can't set DNS filter, filter not initialized") + return "" + } + + firstLayerDecoder := layers.LayerTypeIPv4 + if s.wgInterface.Address().Network.IP.To4() == nil { + firstLayerDecoder = layers.LayerTypeIPv6 + } + + hook := func(packetData []byte) bool { + // Decode the packet + packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default) + + // Get the UDP layer + udpLayer := packet.Layer(layers.LayerTypeUDP) + udp := udpLayer.(*layers.UDP) + + msg := new(dns.Msg) + if err := msg.Unpack(udp.Payload); err != nil { + log.Tracef("parse DNS request: %v", err) + return true + } + + writer := responseWriter{ + packet: packet, + device: s.wgInterface.GetDevice().Device, + } + go s.dnsMux.ServeDNS(&writer, msg) + return true + } + + return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook) +} + +func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string { + // Calculate the last IP in the CIDR range + var endIP net.IP + for i := 0; i < len(network.IP); i++ { + endIP = append(endIP, network.IP[i]|^network.Mask[i]) + } + + // convert to big.Int + endInt := big.NewInt(0) + endInt.SetBytes(endIP) + + // subtract fromEnd from the last ip + fromEndBig := big.NewInt(int64(fromEnd)) + resultInt := big.NewInt(0) + resultInt.Sub(endInt, fromEndBig) + + return net.IP(resultInt.Bytes()).String() +} diff --git a/client/internal/dns/server_nonandroid_test.go b/client/internal/dns/server_nonandroid_test.go new file mode 100644 index 000000000..bea4f4ce8 --- /dev/null +++ b/client/internal/dns/server_nonandroid_test.go @@ -0,0 +1,31 @@ +package dns + +import ( + "net" + "testing" +) + +func TestGetLastIPFromNetwork(t *testing.T) { + tests := []struct { + addr string + ip string + }{ + {"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"}, + {"192.168.0.0/30", "192.168.0.2"}, + {"192.168.0.0/16", "192.168.255.254"}, + {"192.168.0.0/24", "192.168.0.254"}, + } + + for _, tt := range tests { + _, ipnet, err := net.ParseCIDR(tt.addr) + if err != nil { + t.Errorf("Error parsing CIDR: %v", err) + return + } + + lastIP := getLastIPFromNetwork(ipnet, 1) + if lastIP != tt.ip { + t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP) + } + } +} diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 260ac7d67..201392105 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -9,10 +9,9 @@ import ( "testing" "time" - "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/miekg/dns" + "github.com/netbirdio/netbird/client/internal/stdnet" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/iface" ) @@ -238,6 +237,7 @@ func TestUpdateDNSServer(t *testing.T) { dnsServer.updateSerial = testCase.initSerial // pretend we are running dnsServer.listenerIsRunning = true + dnsServer.fakeResolverWG.Add(1) err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate) if err != nil { diff --git a/iface/device_wrapper.go b/iface/device_wrapper.go index d6e39b81b..2fa219395 100644 --- a/iface/device_wrapper.go +++ b/iface/device_wrapper.go @@ -15,6 +15,15 @@ type PacketFilter interface { // DropIncoming filter incoming packets from external sources to host DropIncoming(packetData []byte) bool + // AddUDPPacketHook calls hook when UDP packet from given direction matched + // + // Hook function returns flag which indicates should be the matched package dropped or not. + // Hook function receives raw network packet data as argument. + AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string + + // RemovePacketHook removes hook by ID + RemovePacketHook(hookID string) error + // SetNetwork of the wireguard interface to which filtering applied SetNetwork(*net.IPNet) } @@ -82,8 +91,8 @@ func (d *DeviceWrapper) Write(bufs [][]byte, offset int) (int, error) { return n, err } -// SetFiltering sets packet filter to device -func (d *DeviceWrapper) SetFiltering(filter PacketFilter) { +// SetFilter sets packet filter to device +func (d *DeviceWrapper) SetFilter(filter PacketFilter) { d.mutex.Lock() d.filter = filter d.mutex.Unlock() diff --git a/iface/device_wrapper_test.go b/iface/device_wrapper_test.go index 9d1045100..9f5386587 100644 --- a/iface/device_wrapper_test.go +++ b/iface/device_wrapper_test.go @@ -14,13 +14,6 @@ func TestDeviceWrapperRead(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - tun := mocks.NewMockDevice(ctrl) - filter := mocks.NewMockPacketFilter(ctrl) - - mockBufs := [][]byte{{}} - mockSizes := []int{0} - mockOffset := 0 - t.Run("read ICMP", func(t *testing.T) { ipLayer := &layers.IPv4{ Version: 4, @@ -46,6 +39,11 @@ func TestDeviceWrapperRead(t *testing.T) { return } + mockBufs := [][]byte{{}} + mockSizes := []int{0} + mockOffset := 0 + + tun := mocks.NewMockDevice(ctrl) tun.EXPECT().Read(mockBufs, mockSizes, mockOffset). DoAndReturn(func(bufs [][]byte, sizes []int, offset int) (int, error) { bufs[0] = buffer.Bytes() @@ -95,7 +93,10 @@ func TestDeviceWrapperRead(t *testing.T) { return } + mockBufs := [][]byte{buffer.Bytes()} + mockBufs[0] = buffer.Bytes() + tun := mocks.NewMockDevice(ctrl) tun.EXPECT().Write(mockBufs, 0).Return(1, nil) wrapped := newDeviceWrapper(tun) @@ -138,10 +139,13 @@ func TestDeviceWrapperRead(t *testing.T) { return } - mockBufs = [][]byte{} + mockBufs := [][]byte{} + tun := mocks.NewMockDevice(ctrl) tun.EXPECT().Write(mockBufs, 0).Return(0, nil) - filter.EXPECT().DropOutput(gomock.Any()).Return(true) + + filter := mocks.NewMockPacketFilter(ctrl) + filter.EXPECT().DropIncoming(gomock.Any()).Return(true) wrapped := newDeviceWrapper(tun) wrapped.filter = filter @@ -188,13 +192,15 @@ func TestDeviceWrapperRead(t *testing.T) { mockSizes := []int{0} mockOffset := 0 + tun := mocks.NewMockDevice(ctrl) tun.EXPECT().Read(mockBufs, mockSizes, mockOffset). DoAndReturn(func(bufs [][]byte, sizes []int, offset int) (int, error) { bufs[0] = buffer.Bytes() sizes[0] = len(bufs[0]) return 1, nil }) - filter.EXPECT().DropInput(gomock.Any()).Return(true) + filter := mocks.NewMockPacketFilter(ctrl) + filter.EXPECT().DropOutgoing(gomock.Any()).Return(true) wrapped := newDeviceWrapper(tun) wrapped.filter = filter diff --git a/iface/iface.go b/iface/iface.go index 100e16982..4fcc064d1 100644 --- a/iface/iface.go +++ b/iface/iface.go @@ -23,6 +23,7 @@ type WGIface struct { configurer wGConfigurer mu sync.Mutex userspaceBind bool + filter PacketFilter } // IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind @@ -120,8 +121,8 @@ func (w *WGIface) Close() error { return w.tun.Close() } -// SetFiltering sets packet filters for the userspace impelemntation -func (w *WGIface) SetFiltering(filter PacketFilter) error { +// SetFilter sets packet filters for the userspace impelemntation +func (w *WGIface) SetFilter(filter PacketFilter) error { w.mu.Lock() defer w.mu.Unlock() @@ -129,7 +130,25 @@ func (w *WGIface) SetFiltering(filter PacketFilter) error { return fmt.Errorf("userspace packet filtering not handled on this device") } - filter.SetNetwork(w.tun.address.Network) - w.tun.wrapper.SetFiltering(filter) + w.filter = filter + w.filter.SetNetwork(w.tun.address.Network) + + w.tun.wrapper.SetFilter(filter) return nil } + +// GetFilter returns packet filter used by interface if it uses userspace device implementation +func (w *WGIface) GetFilter() PacketFilter { + w.mu.Lock() + defer w.mu.Unlock() + + return w.filter +} + +// GetDevice to interact with raw device (with filtering) +func (w *WGIface) GetDevice() *DeviceWrapper { + w.mu.Lock() + defer w.mu.Unlock() + + return w.tun.wrapper +} diff --git a/iface/mocks/README.md b/iface/mocks/README.md new file mode 100644 index 000000000..7d3ea1e2c --- /dev/null +++ b/iface/mocks/README.md @@ -0,0 +1,7 @@ +## Mocks + +To generate (or refresh) mocks from iface package interfaces please install [mockgen](https://github.com/golang/mock). +Run this command to update PacketFilter mock: +```bash +mockgen -destination iface/mocks/filter.go -package mocks github.com/netbirdio/netbird/iface PacketFilter +``` diff --git a/iface/mocks/filter.go b/iface/mocks/filter.go index 9537520a2..2d80d69f1 100644 --- a/iface/mocks/filter.go +++ b/iface/mocks/filter.go @@ -34,21 +34,21 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder { return m.recorder } -// DropInput mocks base method. -func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool { +// AddUDPPacketHook mocks base method. +func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func([]byte) bool) string { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DropOutgoing", arg0) - ret0, _ := ret[0].(bool) + ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(string) return ret0 } -// DropInput indicates an expected call of DropInput. -func (mr *MockPacketFilterMockRecorder) DropInput(arg0 interface{}) *gomock.Call { +// AddUDPPacketHook indicates an expected call of AddUDPPacketHook. +func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3) } -// DropOutput mocks base method. +// DropIncoming mocks base method. func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "DropIncoming", arg0) @@ -56,12 +56,40 @@ func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool { return ret0 } -// DropOutput indicates an expected call of DropOutput. -func (mr *MockPacketFilterMockRecorder) DropOutput(arg0 interface{}) *gomock.Call { +// DropIncoming indicates an expected call of DropIncoming. +func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0) } +// DropOutgoing mocks base method. +func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DropOutgoing", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// DropOutgoing indicates an expected call of DropOutgoing. +func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0) +} + +// RemovePacketHook mocks base method. +func (m *MockPacketFilter) RemovePacketHook(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemovePacketHook", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// RemovePacketHook indicates an expected call of RemovePacketHook. +func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0) +} + // SetNetwork mocks base method. func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) { m.ctrl.T.Helper() diff --git a/iface/mocks/iface/mocks/filter.go b/iface/mocks/iface/mocks/filter.go new file mode 100644 index 000000000..059a2b9a0 --- /dev/null +++ b/iface/mocks/iface/mocks/filter.go @@ -0,0 +1,87 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/netbirdio/netbird/iface (interfaces: PacketFilter) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + net "net" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockPacketFilter is a mock of PacketFilter interface. +type MockPacketFilter struct { + ctrl *gomock.Controller + recorder *MockPacketFilterMockRecorder +} + +// MockPacketFilterMockRecorder is the mock recorder for MockPacketFilter. +type MockPacketFilterMockRecorder struct { + mock *MockPacketFilter +} + +// NewMockPacketFilter creates a new mock instance. +func NewMockPacketFilter(ctrl *gomock.Controller) *MockPacketFilter { + mock := &MockPacketFilter{ctrl: ctrl} + mock.recorder = &MockPacketFilterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder { + return m.recorder +} + +// AddUDPPacketHook mocks base method. +func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func(*net.UDPAddr, []byte) bool) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3) +} + +// AddUDPPacketHook indicates an expected call of AddUDPPacketHook. +func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3) +} + +// DropIncoming mocks base method. +func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DropIncoming", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// DropIncoming indicates an expected call of DropIncoming. +func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0) +} + +// DropOutgoing mocks base method. +func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DropOutgoing", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// DropOutgoing indicates an expected call of DropOutgoing. +func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0) +} + +// SetNetwork mocks base method. +func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetNetwork", arg0) +} + +// SetNetwork indicates an expected call of SetNetwork. +func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0) +}