diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index a87d8a68f..2e4b9677a 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -33,3 +33,55 @@ jobs: - name: Test run: GOARCH=${{ matrix.arch }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./... + + test_client_on_docker: + runs-on: ubuntu-latest + steps: + - name: Install Go + uses: actions/setup-go@v2 + with: + go-version: 1.18.x + + + - name: Cache Go modules + uses: actions/cache@v2 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + + - name: Checkout code + uses: actions/checkout@v2 + + - name: Install dependencies + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libappindicator3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev + + - name: Install modules + run: go mod tidy + + - name: Generate Iface Test bin + run: go test -c -o iface-testing.bin ./iface/... + + - name: Generate RouteManager Test bin + run: go test -c -o routemanager-testing.bin ./client/internal/routemanager/... + + - name: Generate Engine Test bin + run: go test -c -o engine-testing.bin ./client/internal/*.go + + - name: Generate Peer Test bin + run: go test -c -o peer-testing.bin ./client/internal/peer/... + + - run: chmod +x *testing.bin + + - name: Run Iface tests in docker + run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin + + - name: Run RouteManager tests in docker + run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin + + - name: Run Engine tests in docker + run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin + + - name: Run Peer tests in docker + run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin \ No newline at end of file diff --git a/client/internal/connect.go b/client/internal/connect.go index ddd8788e7..44bc54fd8 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -107,7 +107,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta localPeerState := nbStatus.LocalPeerState{ IP: loginResp.GetPeerConfig().GetAddress(), PubKey: myPrivateKey.PublicKey().String(), - KernelInterface: iface.WireguardModExists(), + KernelInterface: iface.WireguardModuleIsLoaded(), } statusRecorder.UpdateLocalPeerState(localPeerState) diff --git a/client/internal/routemanager/nftables_linux.go b/client/internal/routemanager/nftables_linux.go index 6201301fc..0874737e7 100644 --- a/client/internal/routemanager/nftables_linux.go +++ b/client/internal/routemanager/nftables_linux.go @@ -84,8 +84,10 @@ func (n *nftablesManager) CleanRoutingRules() { n.mux.Lock() defer n.mux.Unlock() log.Debug("flushing tables") - n.conn.FlushTable(n.tableIPv6) - n.conn.FlushTable(n.tableIPv4) + if n.tableIPv4 != nil && n.tableIPv6 != nil { + n.conn.FlushTable(n.tableIPv6) + n.conn.FlushTable(n.tableIPv4) + } log.Debugf("flushing tables result in: %v error", n.conn.Flush()) } diff --git a/iface/iface_darwin.go b/iface/iface_darwin.go index c84a0fc04..22c089d27 100644 --- a/iface/iface_darwin.go +++ b/iface/iface_darwin.go @@ -34,7 +34,7 @@ func (w *WGIface) assignAddr() error { return nil } -// WireguardModExists check if we can load wireguard mod (linux only) -func WireguardModExists() bool { +// WireguardModuleIsLoaded check if we can load wireguard mod (linux only) +func WireguardModuleIsLoaded() bool { return false } diff --git a/iface/iface_linux.go b/iface/iface_linux.go index 6da54e9bd..cbbe0c6f8 100644 --- a/iface/iface_linux.go +++ b/iface/iface_linux.go @@ -1,48 +1,29 @@ package iface import ( - "errors" - "math" - "os" - "syscall" - + "fmt" log "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" + "os" ) type NativeLink struct { Link *netlink.Link } -// WireguardModExists check if we can load wireguard mod (linux only) -func WireguardModExists() bool { - link := newWGLink("mustnotexist") - - // We willingly try to create a device with an invalid - // MTU here as the validation of the MTU will be performed after - // the validation of the link kind and hence allows us to check - // for the existance of the wireguard module without actually - // creating a link. - // - // As a side-effect, this will also let the kernel lazy-load - // the wireguard module. - link.attrs.MTU = math.MaxInt - - err := netlink.LinkAdd(link) - - return errors.Is(err, syscall.EINVAL) -} - // Create creates a new Wireguard interface, sets a given IP and brings it up. // Will reuse an existing one. func (w *WGIface) Create() error { w.mu.Lock() defer w.mu.Unlock() - if WireguardModExists() { + if WireguardModuleIsLoaded() { log.Info("using kernel WireGuard") return w.createWithKernel() } else { + if !tunModuleIsLoaded() { + return fmt.Errorf("couldn't check or load tun module") + } log.Info("using userspace WireGuard") return w.createWithUserspace() } diff --git a/iface/iface_windows.go b/iface/iface_windows.go index 4c6d4a4ab..d38cd3dc4 100644 --- a/iface/iface_windows.go +++ b/iface/iface_windows.go @@ -58,7 +58,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error { return w.assignAddr(luid) } -// WireguardModExists check if we can load wireguard mod (linux only) -func WireguardModExists() bool { +// WireguardModuleIsLoaded check if we can load wireguard mod (linux only) +func WireguardModuleIsLoaded() bool { return false } diff --git a/iface/module_linux.go b/iface/module_linux.go new file mode 100644 index 000000000..61cbdb967 --- /dev/null +++ b/iface/module_linux.go @@ -0,0 +1,349 @@ +// Package iface provides wireguard network interface creation and management +package iface + +import ( + "bufio" + "errors" + "fmt" + log "github.com/sirupsen/logrus" + "github.com/vishvananda/netlink" + "golang.org/x/sys/unix" + "io/fs" + "io/ioutil" + "math" + "os" + "path/filepath" + "strings" + "syscall" +) + +// Holds logic to check existence of kernel modules used by wireguard interfaces +// Copied from https://github.com/paultag/go-modprobe and +// https://github.com/pmorjan/kmod + +type status int + +const ( + defaultModuleDir = "/lib/modules" + unknown status = iota + unloaded + unloading + loading + live + inuse +) + +type module struct { + name string + path string +} + +var ( + // ErrModuleNotFound is the error resulting if a module can't be found. + ErrModuleNotFound = errors.New("module not found") + moduleLibDir = defaultModuleDir + // get the root directory for the kernel modules. If this line panics, + // it's because getModuleRoot has failed to get the uname of the running + // kernel (likely a non-POSIX system, but maybe a broken kernel?) + moduleRoot = getModuleRoot() +) + +// Get the module root (/lib/modules/$(uname -r)/) +func getModuleRoot() string { + uname := unix.Utsname{} + if err := unix.Uname(&uname); err != nil { + panic(err) + } + + i := 0 + for ; uname.Release[i] != 0; i++ { + } + + return filepath.Join(moduleLibDir, string(uname.Release[:i])) +} + +// tunModuleIsLoaded check if tun module exist, if is not attempt to load it +func tunModuleIsLoaded() bool { + _, err := os.Stat("/dev/net/tun") + if err == nil { + return true + } + + log.Infof("couldn't access device /dev/net/tun, go error %v, "+ + "will attempt to load tun module, if running on container add flag --cap-add=NET_ADMIN", err) + + tunLoaded, err := tryToLoadModule("tun") + if err != nil { + log.Errorf("unable to find or load tun module, got error: %v", err) + } + return tunLoaded +} + +// WireguardModuleIsLoaded check if we can load wireguard mod (linux only) +func WireguardModuleIsLoaded() bool { + if canCreateFakeWireguardInterface() { + return true + } + + loaded, err := tryToLoadModule("wireguard") + if err != nil { + log.Info(err) + return false + } + + return loaded +} + +func canCreateFakeWireguardInterface() bool { + link := newWGLink("mustnotexist") + + // We willingly try to create a device with an invalid + // MTU here as the validation of the MTU will be performed after + // the validation of the link kind and hence allows us to check + // for the existance of the wireguard module without actually + // creating a link. + // + // As a side-effect, this will also let the kernel lazy-load + // the wireguard module. + link.attrs.MTU = math.MaxInt + + err := netlink.LinkAdd(link) + + return errors.Is(err, syscall.EINVAL) +} + +func tryToLoadModule(moduleName string) (bool, error) { + if isModuleEnabled(moduleName) { + return true, nil + } + modulePath, err := getModulePath(moduleName) + if err != nil { + return false, fmt.Errorf("couldn't find module path for %s, error: %v", moduleName, err) + } + if modulePath == "" { + return false, nil + } + + log.Infof("trying to load %s module", moduleName) + + err = loadModuleWithDependencies(moduleName, modulePath) + if err != nil { + return false, fmt.Errorf("couldn't load %s module, error: %v", moduleName, err) + } + return true, nil +} + +func isModuleEnabled(name string) bool { + builtin, builtinErr := isBuiltinModule(name) + state, statusErr := moduleStatus(name) + return (builtinErr == nil && builtin) || (statusErr == nil && state >= loading) +} + +func getModulePath(name string) (string, error) { + var foundPath string + skipRemainingDirs := false + + err := filepath.WalkDir( + moduleRoot, + func(path string, info fs.DirEntry, err error) error { + if skipRemainingDirs { + return fs.SkipDir + } + if err != nil { + // skip broken files + return nil + } + + if !info.Type().IsRegular() { + return nil + } + + nameFromPath := pathToName(path) + if nameFromPath == name { + foundPath = path + skipRemainingDirs = true + } + + return nil + }) + + if err != nil { + return "", err + } + + return foundPath, nil +} + +func pathToName(s string) string { + s = filepath.Base(s) + for ext := filepath.Ext(s); ext != ""; ext = filepath.Ext(s) { + s = strings.TrimSuffix(s, ext) + } + return cleanName(s) +} + +func cleanName(s string) string { + return strings.ReplaceAll(strings.TrimSpace(s), "-", "_") +} + +func isBuiltinModule(name string) (bool, error) { + f, err := os.Open(filepath.Join(moduleRoot, "/modules.builtin")) + if err != nil { + return false, err + } + defer func() { + err := f.Close() + if err != nil { + log.Errorf("failed closing modules.builtin file, %v", err) + } + }() + + var found bool + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := scanner.Text() + if pathToName(line) == name { + found = true + break + } + } + if err := scanner.Err(); err != nil { + return false, err + } + return found, nil +} + +// /proc/modules +// name | memory size | reference count | references | state: +// macvlan 28672 1 macvtap, Live 0x0000000000000000 +func moduleStatus(name string) (status, error) { + state := unknown + f, err := os.Open("/proc/modules") + if err != nil { + return state, err + } + defer func() { + err := f.Close() + if err != nil { + log.Errorf("failed closing /proc/modules file, %v", err) + } + }() + + state = unloaded + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + fields := strings.Fields(scanner.Text()) + if fields[0] == name { + if fields[2] != "0" { + state = inuse + break + } + switch fields[4] { + case "Live": + state = live + case "Loading": + state = loading + case "Unloading": + state = unloading + } + break + } + } + if err := scanner.Err(); err != nil { + return state, err + } + + return state, nil +} + +func loadModuleWithDependencies(name, path string) error { + deps, err := getModuleDependencies(name) + if err != nil { + return fmt.Errorf("couldn't load list of module %s dependecies", name) + } + for _, dep := range deps { + err = loadModule(dep.name, dep.path) + if err != nil { + return fmt.Errorf("couldn't load dependecy module %s for %s", dep.name, name) + } + } + return loadModule(name, path) +} + +func loadModule(name, path string) error { + state, err := moduleStatus(name) + if err != nil { + return err + } + if state >= loading { + return nil + } + + f, err := os.Open(path) + if err != nil { + return err + } + defer func() { + err := f.Close() + if err != nil { + log.Errorf("failed closing %s file, %v", path, err) + } + }() + + // first try finit_module(2), then init_module(2) + err = unix.FinitModule(int(f.Fd()), "", 0) + if errors.Is(err, unix.ENOSYS) { + buf, err := ioutil.ReadAll(f) + if err != nil { + return err + } + return unix.InitModule(buf, "") + } + return err +} + +// getModuleDependencies returns a module dependencies +func getModuleDependencies(name string) ([]module, error) { + f, err := os.Open(filepath.Join(moduleRoot, "/modules.dep")) + if err != nil { + return nil, err + } + defer func() { + err := f.Close() + if err != nil { + log.Errorf("failed closing modules.dep file, %v", err) + } + }() + + var deps []string + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := scanner.Text() + fields := strings.Fields(line) + if pathToName(strings.TrimSuffix(fields[0], ":")) == name { + deps = fields + break + } + } + if err := scanner.Err(); err != nil { + return nil, err + } + + if len(deps) == 0 { + return nil, ErrModuleNotFound + } + deps[0] = strings.TrimSuffix(deps[0], ":") + + var modules []module + for _, v := range deps { + if pathToName(v) != name { + modules = append(modules, module{ + name: pathToName(v), + path: filepath.Join(moduleRoot, v), + }) + } + } + + return modules, nil +} diff --git a/iface/module_linux_test.go b/iface/module_linux_test.go new file mode 100644 index 000000000..62105d1a0 --- /dev/null +++ b/iface/module_linux_test.go @@ -0,0 +1,221 @@ +package iface + +import ( + "bufio" + "bytes" + "github.com/stretchr/testify/require" + "golang.org/x/sys/unix" + "io" + "io/ioutil" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestGetModuleDependencies(t *testing.T) { + testCases := []struct { + name string + module string + expected []module + }{ + { + name: "Get Single Dependency", + module: "bar", + expected: []module{ + {name: "foo", path: "kernel/a/foo.ko"}, + }, + }, + { + name: "Get Multiple Dependencies", + module: "baz", + expected: []module{ + {name: "foo", path: "kernel/a/foo.ko"}, + {name: "bar", path: "kernel/a/bar.ko"}, + }, + }, + { + name: "Get No Dependencies", + module: "foo", + expected: []module{}, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + defer resetGlobals() + _, _ = createFiles(t) + modules, err := getModuleDependencies(testCase.module) + require.NoError(t, err) + + expected := testCase.expected + for i := range expected { + expected[i].path = moduleRoot + "/" + expected[i].path + } + + require.ElementsMatchf(t, modules, expected, "returned modules should match") + }) + } +} + +func TestIsBuiltinModule(t *testing.T) { + testCases := []struct { + name string + module string + expected bool + }{ + { + name: "Built In Should Return True", + module: "foo_bi", + expected: true, + }, + { + name: "Not Built In Should Return False", + module: "not_built_in", + expected: false, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + defer resetGlobals() + _, _ = createFiles(t) + + isBuiltIn, err := isBuiltinModule(testCase.module) + require.NoError(t, err) + require.Equal(t, testCase.expected, isBuiltIn) + }) + } +} + +func TestModuleStatus(t *testing.T) { + random, err := getRandomLoadedModule(t) + if err != nil { + t.Fatal("should be able to get random module") + } + testCases := []struct { + name string + module string + shouldBeLoaded bool + }{ + { + name: "Should Return Module Loading Or Greater Status", + module: random, + shouldBeLoaded: true, + }, + { + name: "Should Return Module Unloaded Or Lower Status", + module: "not_loaded_module", + shouldBeLoaded: false, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + defer resetGlobals() + _, _ = createFiles(t) + + state, err := moduleStatus(testCase.module) + require.NoError(t, err) + if testCase.shouldBeLoaded { + require.GreaterOrEqual(t, loading, state, "moduleStatus for %s should return state loading", testCase.module) + } else { + require.Less(t, state, loading, "module should return state unloading or lower") + } + }) + } +} + +func resetGlobals() { + moduleLibDir = defaultModuleDir + moduleRoot = getModuleRoot() +} + +func createFiles(t *testing.T) (string, []module) { + writeFile := func(path, text string) { + if err := ioutil.WriteFile(path, []byte(text), 0644); err != nil { + t.Fatal(err) + } + } + var u unix.Utsname + if err := unix.Uname(&u); err != nil { + t.Fatal(err) + } + + moduleLibDir = t.TempDir() + + moduleRoot = getModuleRoot() + if err := os.Mkdir(moduleRoot, 0755); err != nil { + t.Fatal(err) + } + + text := "kernel/a/foo.ko:\n" + text += "kernel/a/bar.ko: kernel/a/foo.ko\n" + text += "kernel/a/baz.ko: kernel/a/bar.ko kernel/a/foo.ko\n" + writeFile(filepath.Join(moduleRoot, "/modules.dep"), text) + + text = "kernel/a/foo_bi.ko\n" + text += "kernel/a/bar-bi.ko.gz\n" + writeFile(filepath.Join(moduleRoot, "/modules.builtin"), text) + + modules := []module{ + {name: "foo", path: "kernel/a/foo.ko"}, + {name: "bar", path: "kernel/a/bar.ko"}, + {name: "baz", path: "kernel/a/baz.ko"}, + } + return moduleLibDir, modules +} + +func getRandomLoadedModule(t *testing.T) (string, error) { + f, err := os.Open("/proc/modules") + if err != nil { + return "", err + } + defer func() { + err := f.Close() + if err != nil { + t.Logf("failed closing /proc/modules file, %v", err) + } + }() + lines, err := lineCounter(f) + if err != nil { + return "", err + } + counter := 1 + midLine := lines / 2 + modName := "" + scanner := bufio.NewScanner(f) + for scanner.Scan() { + fields := strings.Fields(scanner.Text()) + if counter == midLine { + if fields[4] == "Unloading" { + continue + } + modName = fields[0] + break + } + counter++ + } + if scanner.Err() != nil { + return "", scanner.Err() + } + return modName, nil +} +func lineCounter(r io.Reader) (int, error) { + buf := make([]byte, 32*1024) + count := 0 + lineSep := []byte{'\n'} + + for { + c, err := r.Read(buf) + count += bytes.Count(buf[:c], lineSep) + + switch { + case err == io.EOF: + return count, nil + + case err != nil: + return count, err + } + } +}