Add forwarder logic

This commit is contained in:
Zoltan Papp
2023-07-26 13:43:44 +02:00
parent 48098c994d
commit 337020bf59
14 changed files with 452 additions and 41 deletions

View File

@@ -0,0 +1,119 @@
// Code generated by bpf2go; DO NOT EDIT.
//go:build arm64be || armbe || mips || mips64 || mips64p32 || ppc64 || s390 || s390x || sparc || sparc64
// +build arm64be armbe mips mips64 mips64p32 ppc64 s390 s390x sparc sparc64
package forwarder
import (
"bytes"
_ "embed"
"fmt"
"io"
"github.com/cilium/ebpf"
)
// loadBpf returns the embedded CollectionSpec for bpf.
func loadBpf() (*ebpf.CollectionSpec, error) {
reader := bytes.NewReader(_BpfBytes)
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
if err != nil {
return nil, fmt.Errorf("can't load bpf: %w", err)
}
return spec, err
}
// loadBpfObjects loads bpf and converts it into a struct.
//
// The following types are suitable as obj argument:
//
// *bpfObjects
// *bpfPrograms
// *bpfMaps
//
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
spec, err := loadBpf()
if err != nil {
return err
}
return spec.LoadAndAssign(obj, opts)
}
// bpfSpecs contains maps and programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfSpecs struct {
bpfProgramSpecs
bpfMapSpecs
}
// bpfSpecs contains programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfProgramSpecs struct {
XdpDnsPortFwd *ebpf.ProgramSpec `ebpf:"xdp_dns_port_fwd"`
}
// bpfMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfMapSpecs struct {
XdpPortMap *ebpf.MapSpec `ebpf:"xdp_port_map"`
}
// bpfObjects contains all objects after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfObjects struct {
bpfPrograms
bpfMaps
}
func (o *bpfObjects) Close() error {
return _BpfClose(
&o.bpfPrograms,
&o.bpfMaps,
)
}
// bpfMaps contains all maps after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfMaps struct {
XdpPortMap *ebpf.Map `ebpf:"xdp_port_map"`
}
func (m *bpfMaps) Close() error {
return _BpfClose(
m.XdpPortMap,
)
}
// bpfPrograms contains all programs after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfPrograms struct {
XdpDnsPortFwd *ebpf.Program `ebpf:"xdp_dns_port_fwd"`
}
func (p *bpfPrograms) Close() error {
return _BpfClose(
p.XdpDnsPortFwd,
)
}
func _BpfClose(closers ...io.Closer) error {
for _, closer := range closers {
if err := closer.Close(); err != nil {
return err
}
}
return nil
}
// Do not access this directly.
//go:embed bpf_bpfeb.o
var _BpfBytes []byte

Binary file not shown.

View File

@@ -0,0 +1,119 @@
// Code generated by bpf2go; DO NOT EDIT.
//go:build 386 || amd64 || amd64p32 || arm || arm64 || mips64le || mips64p32le || mipsle || ppc64le || riscv64
// +build 386 amd64 amd64p32 arm arm64 mips64le mips64p32le mipsle ppc64le riscv64
package forwarder
import (
"bytes"
_ "embed"
"fmt"
"io"
"github.com/cilium/ebpf"
)
// loadBpf returns the embedded CollectionSpec for bpf.
func loadBpf() (*ebpf.CollectionSpec, error) {
reader := bytes.NewReader(_BpfBytes)
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
if err != nil {
return nil, fmt.Errorf("can't load bpf: %w", err)
}
return spec, err
}
// loadBpfObjects loads bpf and converts it into a struct.
//
// The following types are suitable as obj argument:
//
// *bpfObjects
// *bpfPrograms
// *bpfMaps
//
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
spec, err := loadBpf()
if err != nil {
return err
}
return spec.LoadAndAssign(obj, opts)
}
// bpfSpecs contains maps and programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfSpecs struct {
bpfProgramSpecs
bpfMapSpecs
}
// bpfSpecs contains programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfProgramSpecs struct {
XdpDnsPortFwd *ebpf.ProgramSpec `ebpf:"xdp_dns_port_fwd"`
}
// bpfMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfMapSpecs struct {
XdpPortMap *ebpf.MapSpec `ebpf:"xdp_port_map"`
}
// bpfObjects contains all objects after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfObjects struct {
bpfPrograms
bpfMaps
}
func (o *bpfObjects) Close() error {
return _BpfClose(
&o.bpfPrograms,
&o.bpfMaps,
)
}
// bpfMaps contains all maps after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfMaps struct {
XdpPortMap *ebpf.Map `ebpf:"xdp_port_map"`
}
func (m *bpfMaps) Close() error {
return _BpfClose(
m.XdpPortMap,
)
}
// bpfPrograms contains all programs after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfPrograms struct {
XdpDnsPortFwd *ebpf.Program `ebpf:"xdp_dns_port_fwd"`
}
func (p *bpfPrograms) Close() error {
return _BpfClose(
p.XdpDnsPortFwd,
)
}
func _BpfClose(closers ...io.Closer) error {
for _, closer := range closers {
if err := closer.Close(); err != nil {
return err
}
}
return nil
}
// Do not access this directly.
//go:embed bpf_bpfel.o
var _BpfBytes []byte

Binary file not shown.

View File

@@ -0,0 +1,89 @@
#include <stdbool.h>
#include <linux/if_ether.h> // ETH_P_IP
#include <linux/udp.h>
#include <linux/ip.h>
#include <netinet/in.h>
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>
#define bpf_printk(fmt, ...) \
({ \
char ____fmt[] = fmt; \
bpf_trace_printk(____fmt, sizeof(____fmt), ##__VA_ARGS__); \
})
const __u32 map_key_dns_ip = 0;
const __u32 map_key_dns_port = 1;
struct bpf_map_def SEC("maps") xdp_port_map = {
.type = BPF_MAP_TYPE_ARRAY,
.key_size = sizeof(__u32),
.value_size = sizeof(__u16),
.max_entries = 10,
};
__be32 dns_ip = 0;
__be16 dns_port = 0;
bool read_port_settings() {
__u16 *value;
__be32 *ip_value;
value = bpf_map_lookup_elem(&xdp_port_map, &map_key_dns_port);
if(!value) {
return false;
}
dns_port = htons(*value);
ip_value = bpf_map_lookup_elem(&xdp_port_map, &map_key_dns_ip);
if(!ip_value) {
return false;
}
dns_ip = htonl(*ip_value);
return true;
}
SEC("xdp")
int xdp_dns_port_fwd(struct xdp_md *ctx) {
if(dns_port == 0) {
if(!read_port_settings()){
return XDP_PASS;
}
bpf_printk("dns port: %d", dns_port);
bpf_printk("dns ip: %d", dns_ip);
}
void *data = (void *)(long)ctx->data;
void *data_end = (void *)(long)ctx->data_end;
struct ethhdr *eth = data;
struct iphdr *ip = (data + sizeof(struct ethhdr));
struct udphdr *udp = (data + sizeof(struct ethhdr) + sizeof(struct iphdr));
// return early if not enough data
if (data + sizeof(struct ethhdr) + sizeof(struct iphdr) + sizeof(struct udphdr) > data_end){
return XDP_PASS;
}
// skip non IPv4 packages
if (eth->h_proto != htons(ETH_P_IP)) {
return XDP_PASS;
}
if (ip->protocol != IPPROTO_UDP) {
return XDP_PASS;
}
// 2130706433 = 127.0.0.1
if (ip->daddr != dns_ip) {
return XDP_PASS;
}
// skip non dns ports
if (udp->source != htons(53)){
return XDP_PASS;
}
udp->dest = dns_port;
return XDP_PASS;
}
char _license[] SEC("license") = "GPL";

View File

@@ -0,0 +1,89 @@
package forwarder
import (
_ "embed"
"encoding/binary"
log "github.com/sirupsen/logrus"
"net"
"github.com/cilium/ebpf/link"
"github.com/cilium/ebpf/rlimit"
)
const (
mapKeyDNSIP uint32 = 0
mapKeyDNSPort uint32 = 1
)
// libbpf-dev, libc6-dev-i386-amd64-cross
//
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc clang-14 bpf port_fwd.c -- -I /usr/x86_64-linux-gnu/include
type TrafficForwarder struct {
link link.Link
iFaceName string
}
func NewTrafficForwarder(iFace string) *TrafficForwarder {
return &TrafficForwarder{
iFaceName: iFace,
}
}
func (tf *TrafficForwarder) Start(ip string, dnsPort int) error {
log.Debugf("start DNS port forwarder")
// it required for Docker
err := rlimit.RemoveMemlock()
if err != nil {
return err
}
iFace, err := net.InterfaceByName(tf.iFaceName)
if err != nil {
return err
}
// load pre-compiled programs into the kernel.
objs := bpfObjects{}
err = loadBpfObjects(&objs, nil)
if err != nil {
return err
}
defer func() {
_ = objs.Close()
}()
err = objs.XdpPortMap.Put(mapKeyDNSIP, tf.ip2int(ip))
if err != nil {
return err
}
err = objs.XdpPortMap.Put(mapKeyDNSPort, uint16(dnsPort))
if err != nil {
return err
}
defer func() {
_ = objs.XdpPortMap.Close()
}()
tf.link, err = link.AttachXDP(link.XDPOptions{
Program: objs.XdpDnsPortFwd,
Interface: iFace.Index,
})
return err
}
func (tf *TrafficForwarder) Free() error {
if tf.link == nil {
return nil
}
err := tf.link.Close()
tf.link = nil
return err
}
func (tf *TrafficForwarder) ip2int(ipString string) uint32 {
ip := net.ParseIP(ipString)
return binary.BigEndian.Uint32(ip.To4())
}

View File

@@ -132,7 +132,7 @@ func (s *DefaultServer) Initialize() (err error) {
// When kernel space interface used it return real DNS server listener IP address
// For bind interface, fake DNS resolver address returned (second last IP address from Nebird network)
func (s *DefaultServer) DnsIP() string {
return s.service.RuntimeIP()
return s.service.ListenIp()
}
// Stop stops the server
@@ -233,16 +233,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.updateMux(muxUpdates)
s.updateLocalResolver(localRecords)
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.ListenIp(), s.service.ListenPort())
hostUpdate := s.currentConfig
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
"Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver")
hostUpdate.routeAll = false
}
if err = s.hostManager.applyDNSConfig(hostUpdate); err != nil {
if err = s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
log.Error(err)
}

View File

@@ -487,7 +487,7 @@ func TestDNSServerStartStop(t *testing.T) {
d := net.Dialer{
Timeout: time.Second * 5,
}
addr := fmt.Sprintf("%s:%d", dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
addr := fmt.Sprintf("%s:%d", dnsServer.service.ListenIp(), dnsServer.service.ListenPort())
conn, err := d.DialContext(ctx, network, addr)
if err != nil {
t.Log(err)
@@ -603,7 +603,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
dnsServer.OnUpdatedHostDNSServer([]string{"8.8.8.8"})
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
resolver := newDnsResolver(dnsServer.service.ListenIp(), dnsServer.service.ListenPort())
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
@@ -626,7 +626,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
defer dnsServer.Stop()
// check initial state
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
resolver := newDnsResolver(dnsServer.service.ListenIp(), dnsServer.service.ListenPort())
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
@@ -718,7 +718,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
defer dnsServer.Stop()
// check initial state
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
resolver := newDnsResolver(dnsServer.service.ListenIp(), dnsServer.service.ListenPort())
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)

View File

@@ -13,6 +13,6 @@ type service interface {
Stop()
RegisterMux(domain string, handler dns.Handler)
DeregisterMux(key string)
RuntimePort() int
RuntimeIP() string
ListenPort() int
ListenIp() string
}

View File

@@ -3,6 +3,7 @@ package dns
import (
"context"
"fmt"
"github.com/netbirdio/netbird/client/internal/dns/forwarder"
"net"
"net/netip"
"runtime"
@@ -28,6 +29,7 @@ type serviceViaListener struct {
runtimePort int
listenerIsRunning bool
listenerFlagLock sync.Mutex
trafficForwarder *forwarder.TrafficForwarder
}
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
@@ -42,6 +44,7 @@ func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *service
Handler: mux,
UDPSize: 65535,
},
trafficForwarder: forwarder.NewTrafficForwarder(wgIface.Name()),
}
return s
}
@@ -72,6 +75,13 @@ func (s *serviceViaListener) Listen() error {
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
}
}()
if s.runtimePort != defaultPort {
err = s.trafficForwarder.Start(s.runtimeIP, s.runtimePort)
if err != nil {
return err
}
}
return nil
}
@@ -90,6 +100,11 @@ func (s *serviceViaListener) Stop() {
if err != nil {
log.Errorf("stopping dns server listener returned an error: %v", err)
}
err = s.trafficForwarder.Free()
if err != nil {
log.Errorf("stopping traffic forwarder returned an error: %v", err)
}
}
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
@@ -100,11 +115,11 @@ func (s *serviceViaListener) DeregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}
func (s *serviceViaListener) RuntimePort() int {
return s.runtimePort
func (s *serviceViaListener) ListenPort() int {
return defaultPort
}
func (s *serviceViaListener) RuntimeIP() string {
func (s *serviceViaListener) ListenIp() string {
return s.runtimeIP
}
@@ -117,7 +132,8 @@ func (s *serviceViaListener) getFirstListenerAvailable() (string, int, error) {
if runtime.GOOS != "darwin" {
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
}
ports := []int{defaultPort, customPort}
//todo change the order back
ports := []int{customPort, defaultPort}
for _, port := range ports {
for _, ip := range ips {
addrString := fmt.Sprintf("%s:%d", ip, port)

View File

@@ -48,7 +48,7 @@ func (s *serviceViaMemory) Listen() error {
}
s.listenerIsRunning = true
log.Debugf("dns service listening on: %s", s.RuntimeIP())
log.Debugf("dns service listening on: %s", s.ListenIp())
return nil
}
@@ -75,11 +75,11 @@ func (s *serviceViaMemory) DeregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}
func (s *serviceViaMemory) RuntimePort() int {
func (s *serviceViaMemory) ListenPort() int {
return s.runtimePort
}
func (s *serviceViaMemory) RuntimeIP() string {
func (s *serviceViaMemory) ListenIp() string {
return s.runtimeIP
}