mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-14 20:09:54 +00:00
Compare commits
22 Commits
socket-grp
...
feature/pr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a89a5c5b9c | ||
|
|
07f3c75267 | ||
|
|
9d21a2c8c1 | ||
|
|
ddc4904912 | ||
|
|
b19467e3af | ||
|
|
2bcea9d582 | ||
|
|
8ff3b06cf1 | ||
|
|
d7703767d5 | ||
|
|
7feda907ca | ||
|
|
62da482133 | ||
|
|
079bce3c2f | ||
|
|
1a09aa6715 | ||
|
|
61abf5b9ea | ||
|
|
e229050ba3 | ||
|
|
e919b2d55d | ||
|
|
a40028092d | ||
|
|
13200265d8 | ||
|
|
ed7a9363aa | ||
|
|
d56859dc5d | ||
|
|
367d37050b | ||
|
|
106527182f | ||
|
|
8e1d5b78c2 |
4
.github/workflows/golang-test-linux.yml
vendored
4
.github/workflows/golang-test-linux.yml
vendored
@@ -158,7 +158,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
@@ -229,7 +229,7 @@ jobs:
|
||||
sh -c ' \
|
||||
apk update; apk add --no-cache \
|
||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
|
||||
go test -buildvcs=false -tags "devcert privileged" -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server -e /client/testutil/privileged)
|
||||
'
|
||||
|
||||
test_relay:
|
||||
|
||||
14
Makefile
14
Makefile
@@ -1,4 +1,4 @@
|
||||
.PHONY: lint lint-all lint-install setup-hooks
|
||||
.PHONY: lint lint-all lint-install setup-hooks test-unit test-privileged
|
||||
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
|
||||
|
||||
# Install golangci-lint locally if needed
|
||||
@@ -25,3 +25,15 @@ setup-hooks:
|
||||
@git config core.hooksPath .githooks
|
||||
@chmod +x .githooks/pre-push
|
||||
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"
|
||||
|
||||
# Host-safe unit tests: excludes the privileged-tagged tests (root / system-mutating).
|
||||
# Runs as a normal user with no sudo and leaves host networking untouched.
|
||||
test-unit:
|
||||
@go test -tags devcert -timeout 10m ./...
|
||||
|
||||
# Privileged suite: runs the `privileged`-tagged tests inside a --privileged
|
||||
# --cap-add=NET_ADMIN container via the ory/dockertest harness. Requires Docker.
|
||||
# Narrow the run with env vars, e.g.:
|
||||
# PRIV_RUN=TestNftablesManager PRIV_PKGS=./client/firewall/nftables/... make test-privileged
|
||||
test-privileged:
|
||||
@go test -tags 'devcert privileged' -timeout 30m -run TestRunPrivilegedSuiteInDocker -v ./client/testutil/privileged/...
|
||||
|
||||
@@ -3,12 +3,14 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/user"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
@@ -85,6 +87,73 @@ var persistenceCmd = &cobra.Command{
|
||||
RunE: setSyncResponsePersistence,
|
||||
}
|
||||
|
||||
var debugConfigCmd = &cobra.Command{
|
||||
Use: "config",
|
||||
Example: " netbird debug config",
|
||||
Short: "Dump the effective configuration",
|
||||
Long: "Prints the daemon's resolved configuration (after applying defaults, file, env, CLI input, and MDM policy overrides) as JSON. Includes the list of MDM-managed fields.",
|
||||
RunE: debugConfigDump,
|
||||
}
|
||||
|
||||
// debugConfigDump implements `netbird debug config`. It resolves the
|
||||
// active profile, queries the daemon for the effective configuration
|
||||
// via GetConfig, and prints the resulting GetConfigResponse as JSON
|
||||
// (via protojson with EmitUnpopulated=true so the output is stable
|
||||
// across runs and includes zero-valued fields).
|
||||
//
|
||||
// Useful for verifying MDM enforcement end-to-end: the response's
|
||||
// mDMManagedFields array is the single source of truth for "which
|
||||
// fields is the daemon currently enforcing from the MDM source", and
|
||||
// every config field side-by-side with that list confirms the merge
|
||||
// result. Secrets in the response (e.g. PreSharedKey) are already
|
||||
// redacted by the daemon-side handler.
|
||||
func debugConfigDump(cmd *cobra.Command, _ []string) error {
|
||||
pm := profilemanager.NewProfileManager()
|
||||
activeProf, err := pm.GetActiveProfile()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile: %v", err)
|
||||
}
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get current user: %v", err)
|
||||
}
|
||||
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.GetConfig(cmd.Context(), &proto.GetConfigRequest{
|
||||
ProfileName: activeProf.Name,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get config: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
// Use protojson so well-known fields render correctly; emit defaults so
|
||||
// the operator sees every field even when zero/empty.
|
||||
m := protojson.MarshalOptions{Multiline: true, Indent: " ", EmitUnpopulated: true}
|
||||
out, err := m.Marshal(resp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal config: %w", err)
|
||||
}
|
||||
cmd.Println(string(out))
|
||||
return nil
|
||||
}
|
||||
|
||||
// debugBundle requests the daemon to create a debug bundle and prints
|
||||
// the resulting local file path and, if uploaded, the uploaded file
|
||||
// key. It uses the package flags (anonymize, system info, log file
|
||||
// count, CLI version, optional upload URL) to configure the bundle
|
||||
// request. Returns an error if the RPC fails or if the daemon reports
|
||||
// an upload failure reason.
|
||||
func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
|
||||
301
client/cmd/kubernetes.go
Normal file
301
client/cmd/kubernetes.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
KubernetesDNSSuffix = "netbird-kubeapi-proxy"
|
||||
)
|
||||
|
||||
var kubernetesCmd = &cobra.Command{
|
||||
Use: "kubernetes",
|
||||
Short: "Kubernetes cluster commands.",
|
||||
Long: "Kubernetes cluster commands.",
|
||||
}
|
||||
|
||||
var kubernetesListCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
RunE: kubernetesList,
|
||||
Short: "List Kubernetes clusters.",
|
||||
Long: "List Kubernetes clusters by discovering NetBird peers running netbird-kubeapi-proxy.",
|
||||
}
|
||||
|
||||
var kubernetesWriteKubeconfigCmd = &cobra.Command{
|
||||
Use: "write-kubeconfig",
|
||||
RunE: kubernetesWriteKubeconfig,
|
||||
Args: cobra.ExactArgs(1),
|
||||
Short: "Write kubeconfig for a Kubernetes cluster.",
|
||||
Long: "Updates kubeconfig in place to allow token-less access to the Kubernetes cluster through NetBird.",
|
||||
}
|
||||
|
||||
func init() {
|
||||
kubernetesWriteKubeconfigCmd.Flags().String("kubeconfig", "", "path to kubeconfig file")
|
||||
}
|
||||
|
||||
func kubernetesList(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
statusResp, err := client.Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
kcs, err := getKubernetesClusters(cmd.Context(), statusResp.FullStatus.Peers, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(kcs) == 0 {
|
||||
cmd.Println("No Kubernetes clusters available.")
|
||||
return nil
|
||||
}
|
||||
cmd.Println("Available Kubernetes clusters:")
|
||||
for _, k := range kcs {
|
||||
cmd.Printf("\n - Name: %s\n FQDN: %s\n Version: %s\n", k.name, k.url.Host, k.version)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func kubernetesWriteKubeconfig(cmd *cobra.Command, args []string) error {
|
||||
kubeconfigPath, err := resolveKubeconfigPath(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
statusResp, err := client.Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
clusterName := args[0]
|
||||
kcs, err := getKubernetesClusters(cmd.Context(), statusResp.FullStatus.Peers, clusterName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(kcs) == 0 {
|
||||
return fmt.Errorf("kubernetes cluster named %s not found", clusterName)
|
||||
}
|
||||
if len(kcs) > 1 {
|
||||
return fmt.Errorf("too many Kubernetes clusters returned")
|
||||
}
|
||||
err = writeKubeconfig(kubeconfigPath, kcs[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type kubernetesCluster struct {
|
||||
name string
|
||||
url *url.URL
|
||||
version string
|
||||
}
|
||||
|
||||
func getKubernetesClusters(ctx context.Context, peers []*proto.PeerState, nameFilter string) ([]kubernetesCluster, error) {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
httpClient := &http.Client{
|
||||
Transport: transport,
|
||||
}
|
||||
resolver := net.Resolver{
|
||||
// Required so both DNS records are returned.
|
||||
// https://github.com/golang/go/issues/17093
|
||||
PreferGo: true,
|
||||
}
|
||||
|
||||
kcs := []kubernetesCluster{}
|
||||
attempted := map[string]struct{}{}
|
||||
for _, peer := range peers {
|
||||
fqdns, err := resolver.LookupAddr(ctx, peer.IP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, fqdn := range fqdns {
|
||||
if _, ok := attempted[fqdn]; ok {
|
||||
continue
|
||||
}
|
||||
attempted[fqdn] = struct{}{}
|
||||
comps := strings.Split(fqdn, ".")
|
||||
if len(comps) < 2 {
|
||||
continue
|
||||
}
|
||||
if comps[1] != KubernetesDNSSuffix {
|
||||
continue
|
||||
}
|
||||
if nameFilter != "" && nameFilter != comps[0] {
|
||||
continue
|
||||
}
|
||||
clusterURL, clusterVersion, err := fingerprintClusters(ctx, httpClient, fqdn)
|
||||
if err != nil {
|
||||
log.Debugf("could not fingerprint Kubernetes cluster %s %q", fqdn, err)
|
||||
continue
|
||||
}
|
||||
kc := kubernetesCluster{
|
||||
name: comps[0],
|
||||
url: clusterURL,
|
||||
version: clusterVersion,
|
||||
}
|
||||
if nameFilter != "" {
|
||||
return []kubernetesCluster{kc}, nil
|
||||
}
|
||||
kcs = append(kcs, kc)
|
||||
}
|
||||
}
|
||||
return kcs, nil
|
||||
}
|
||||
|
||||
func fingerprintClusters(ctx context.Context, httpClient *http.Client, fqdn string) (*url.URL, string, error) {
|
||||
clusterURL, err := url.Parse("https://" + fqdn)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
versionURL, err := clusterURL.Parse("/version")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, versionURL.String(), nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, "", fmt.Errorf("expected %d response but got %s", http.StatusOK, resp.Status)
|
||||
}
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
versionData := map[string]string{}
|
||||
err = json.Unmarshal(b, &versionData)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
version, ok := versionData["gitVersion"]
|
||||
if !ok {
|
||||
return nil, "", errors.New("no version found in response")
|
||||
}
|
||||
return clusterURL, version, nil
|
||||
}
|
||||
|
||||
func resolveKubeconfigPath(cmd *cobra.Command) (string, error) {
|
||||
if cmd.Flags().Changed("kubeconfig") {
|
||||
path, err := cmd.Flags().GetString("kubeconfig")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
if env := os.Getenv("KUBECONFIG"); env != "" {
|
||||
return env, nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not determine home directory: %w", err)
|
||||
}
|
||||
return filepath.Join(home, ".kube", "config"), nil
|
||||
}
|
||||
|
||||
func writeKubeconfig(kubeconfigPath string, kc kubernetesCluster) error {
|
||||
b, err := os.ReadFile(kubeconfigPath)
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return err
|
||||
}
|
||||
var cfg map[string]any
|
||||
if err := yaml.Unmarshal(b, &cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
if cfg == nil {
|
||||
cfg = map[string]any{
|
||||
"apiVersion": "v1",
|
||||
"kind": "Config",
|
||||
}
|
||||
}
|
||||
|
||||
cfg["clusters"] = appendWithName(cfg["clusters"], map[string]any{
|
||||
"name": kc.name,
|
||||
"cluster": map[string]any{
|
||||
"server": kc.url.String(),
|
||||
"insecure-skip-tls-verify": true,
|
||||
},
|
||||
})
|
||||
cfg["users"] = appendWithName(cfg["users"], map[string]any{
|
||||
"name": "netbird",
|
||||
"user": map[string]any{
|
||||
"token": "none",
|
||||
},
|
||||
})
|
||||
cfg["contexts"] = appendWithName(cfg["contexts"], map[string]any{
|
||||
"name": kc.name,
|
||||
"context": map[string]any{
|
||||
"cluster": kc.name,
|
||||
"user": "netbird",
|
||||
"namespace": "default",
|
||||
},
|
||||
})
|
||||
cfg["current-context"] = kc.name
|
||||
|
||||
out, err := yaml.Marshal(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(kubeconfigPath, out, 0o600); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func appendWithName(data any, add map[string]any) any {
|
||||
if data == nil {
|
||||
return []any{add}
|
||||
}
|
||||
v, ok := data.([]any)
|
||||
if !ok {
|
||||
return []any{add}
|
||||
}
|
||||
i := slices.IndexFunc(v, func(item any) bool {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return m["name"] == add["name"]
|
||||
})
|
||||
if i == -1 {
|
||||
return append(v, add)
|
||||
}
|
||||
v[i] = add
|
||||
return v
|
||||
}
|
||||
120
client/cmd/kubernetes_test.go
Normal file
120
client/cmd/kubernetes_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFingerprintClusters(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
//nolint: errcheck
|
||||
w.Write([]byte(`{"gitVersion": "foobar"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
clusterURL, clusterVersion, err := fingerprintClusters(t.Context(), srv.Client(), srv.Listener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, srv.URL, clusterURL.String())
|
||||
require.Equal(t, "foobar", clusterVersion)
|
||||
}
|
||||
|
||||
func TestResolveKubeconfigPath(t *testing.T) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
t.Fatalf("could not determine home directory: %v", err)
|
||||
}
|
||||
defaultPath := filepath.Join(home, ".kube", "config")
|
||||
path, err := resolveKubeconfigPath(&cobra.Command{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, defaultPath, path)
|
||||
|
||||
flagPath := "flag-path"
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().String("kubeconfig", "", "")
|
||||
err = cmd.Flags().Set("kubeconfig", flagPath)
|
||||
require.NoError(t, err)
|
||||
path, err = resolveKubeconfigPath(cmd)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, flagPath, path)
|
||||
|
||||
envPath := "env-path"
|
||||
t.Setenv("KUBECONFIG", envPath)
|
||||
path, err = resolveKubeconfigPath(&cobra.Command{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, envPath, path)
|
||||
}
|
||||
|
||||
func TestWriteKubeconfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
existing string
|
||||
}{
|
||||
{
|
||||
name: "empty file",
|
||||
},
|
||||
{
|
||||
name: "existing content",
|
||||
existing: `apiVersion: v1
|
||||
clusters:
|
||||
- cluster:
|
||||
insecure-skip-tls-verify: true
|
||||
server: https://foobar.com
|
||||
name: foo
|
||||
current-context: test
|
||||
kind: Config
|
||||
users: []
|
||||
`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
kubeconfigPath := filepath.Join(t.TempDir(), "config")
|
||||
err := os.WriteFile(kubeconfigPath, []byte(tt.existing), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
kc := kubernetesCluster{
|
||||
name: "foo",
|
||||
url: &url.URL{Scheme: "https", Host: "example.com"},
|
||||
}
|
||||
err = writeKubeconfig(kubeconfigPath, kc)
|
||||
require.NoError(t, err)
|
||||
|
||||
b, err := os.ReadFile(kubeconfigPath)
|
||||
require.NoError(t, err)
|
||||
expected := `apiVersion: v1
|
||||
clusters:
|
||||
- cluster:
|
||||
insecure-skip-tls-verify: true
|
||||
server: https://example.com
|
||||
name: foo
|
||||
contexts:
|
||||
- context:
|
||||
cluster: foo
|
||||
namespace: default
|
||||
user: netbird
|
||||
name: foo
|
||||
current-context: foo
|
||||
kind: Config
|
||||
users:
|
||||
- name: netbird
|
||||
user:
|
||||
token: none
|
||||
`
|
||||
require.Equal(t, expected, string(b))
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,36 +0,0 @@
|
||||
//go:build darwin || freebsd
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// peerUID returns the uid of the process on the other end of a unix socket
|
||||
// connection, read via LOCAL_PEERCRED (xucred). Note: xucred carries the uid
|
||||
// and group list but no pid, so audit on these platforms is uid-based.
|
||||
func peerUID(c net.Conn) (int, error) {
|
||||
uc, ok := c.(*net.UnixConn)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("connection is not a unix socket: %T", c)
|
||||
}
|
||||
raw, err := uc.SyscallConn()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("raw conn: %w", err)
|
||||
}
|
||||
|
||||
var cred *unix.Xucred
|
||||
var credErr error
|
||||
if err := raw.Control(func(fd uintptr) {
|
||||
cred, credErr = unix.GetsockoptXucred(int(fd), unix.SOL_LOCAL, unix.LOCAL_PEERCRED)
|
||||
}); err != nil {
|
||||
return 0, fmt.Errorf("getsockopt control: %w", err)
|
||||
}
|
||||
if credErr != nil {
|
||||
return 0, fmt.Errorf("LOCAL_PEERCRED: %w", credErr)
|
||||
}
|
||||
return int(cred.Uid), nil
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
//go:build linux
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// peerUID returns the uid of the process on the other end of a unix socket
|
||||
// connection, read from the kernel via SO_PEERCRED.
|
||||
func peerUID(c net.Conn) (int, error) {
|
||||
uc, ok := c.(*net.UnixConn)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("connection is not a unix socket: %T", c)
|
||||
}
|
||||
raw, err := uc.SyscallConn()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("raw conn: %w", err)
|
||||
}
|
||||
|
||||
var cred *unix.Ucred
|
||||
var credErr error
|
||||
if err := raw.Control(func(fd uintptr) {
|
||||
cred, credErr = unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED)
|
||||
}); err != nil {
|
||||
return 0, fmt.Errorf("getsockopt control: %w", err)
|
||||
}
|
||||
if credErr != nil {
|
||||
return 0, fmt.Errorf("SO_PEERCRED: %w", credErr)
|
||||
}
|
||||
return int(cred.Uid), nil
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
//go:build !linux && !darwin && !freebsd
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// peerUID is unimplemented on this platform, so the trust-on-first-use socket
|
||||
// migration cannot run here. Configure --socket-owner explicitly, or use
|
||||
// --disable-strict-socket. (Windows uses a TCP socket and never reaches this.)
|
||||
func peerUID(net.Conn) (int, error) {
|
||||
return 0, fmt.Errorf("peer credential check not supported on %s", runtime.GOOS)
|
||||
}
|
||||
@@ -77,8 +77,6 @@ var (
|
||||
updateSettingsDisabled bool
|
||||
captureEnabled bool
|
||||
networksDisabled bool
|
||||
socketOwner string
|
||||
strictSocketDisabled bool
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird",
|
||||
@@ -97,7 +95,9 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
// Execute executes the root command.
|
||||
// Execute runs the appropriate Cobra command for the CLI.
|
||||
// If the process is the update binary it delegates to updateCmd; otherwise it runs the root command.
|
||||
// It returns any error produced during command execution.
|
||||
func Execute() error {
|
||||
if isUpdateBinary() {
|
||||
return updateCmd.Execute()
|
||||
@@ -105,6 +105,16 @@ func Execute() error {
|
||||
return rootCmd.Execute()
|
||||
}
|
||||
|
||||
// init initialises package-level defaults and configures the root
|
||||
// Cobra command tree. Sets platform-specific config / log directory
|
||||
// paths (including legacy Wiretrustee fallbacks) and a default daemon
|
||||
// address; registers persistent CLI flags (daemon address,
|
||||
// management / admin URLs, logging, setup key (file and inline,
|
||||
// mutually exclusive), preshared key, hostname, anonymise, config
|
||||
// path); attaches top-level and nested subcommands to the root
|
||||
// command; and registers `up`-specific persistent flags (external IP
|
||||
// maps, custom DNS resolver address, Rosenpass options, auto-connect
|
||||
// disabling, lazy connection).
|
||||
func init() {
|
||||
defaultConfigPathDir = "/etc/netbird/"
|
||||
defaultLogFileDir = "/var/log/netbird/"
|
||||
@@ -170,6 +180,12 @@ func init() {
|
||||
logCmd.AddCommand(logLevelCmd)
|
||||
debugCmd.AddCommand(forCmd)
|
||||
debugCmd.AddCommand(persistenceCmd)
|
||||
debugCmd.AddCommand(debugConfigCmd)
|
||||
|
||||
// kubernetes commands
|
||||
rootCmd.AddCommand(kubernetesCmd)
|
||||
kubernetesCmd.AddCommand(kubernetesListCmd)
|
||||
kubernetesCmd.AddCommand(kubernetesWriteKubeconfigCmd)
|
||||
|
||||
// profile commands
|
||||
profileCmd.AddCommand(profileListCmd)
|
||||
|
||||
@@ -57,9 +57,6 @@ func init() {
|
||||
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
|
||||
reconfigureCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
|
||||
|
||||
serviceCmd.PersistentFlags().StringVar(&socketOwner, "socket-owner", "", "user to own the daemon control socket; restricts it to that user plus the netbird group (0660). If unset, the first client to connect claims ownership (trust-on-first-use)")
|
||||
serviceCmd.PersistentFlags().BoolVar(&strictSocketDisabled, "disable-strict-socket", false, "leave the daemon control socket world-writable (0666) instead of restricting it; set via the (root-only) service command")
|
||||
|
||||
rootCmd.AddCommand(serviceCmd)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,15 +4,10 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
@@ -21,7 +16,6 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/shell"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
@@ -60,36 +54,10 @@ func (p *program) Start(svc service.Service) error {
|
||||
go func() {
|
||||
defer listen.Close()
|
||||
|
||||
srvListener := listen
|
||||
if split[0] == "unix" {
|
||||
owner := effectiveSocketOwner()
|
||||
switch {
|
||||
case strictSocketDisabled:
|
||||
// Opt-out (root-only, via service.json): leave it world-writable.
|
||||
if err := os.Chmod(split[1], 0666); err != nil {
|
||||
log.Errorf("failed setting daemon permissions: %v", split[1])
|
||||
return
|
||||
}
|
||||
case owner != "":
|
||||
// Seeded owner (flag, MDM, or persisted TOFU result): restrict
|
||||
// before serving so there is no open window.
|
||||
uid, err := lookupUser(owner)
|
||||
if err != nil {
|
||||
log.Errorf("lookup socket owner %q: %v", owner, err)
|
||||
return
|
||||
}
|
||||
if err := restrictSocket(split[1], uid); err != nil {
|
||||
log.Errorf("restrict socket to %q: %v", owner, err)
|
||||
return
|
||||
}
|
||||
default:
|
||||
// Trust-on-first-use: open the socket now; tofuListener locks it
|
||||
// to the first caller's uid on the first connection.
|
||||
if err := os.Chmod(split[1], 0666); err != nil {
|
||||
log.Errorf("failed setting daemon permissions: %v", split[1])
|
||||
return
|
||||
}
|
||||
srvListener = &tofuListener{Listener: listen, path: split[1], owner: -1}
|
||||
if err := os.Chmod(split[1], 0666); err != nil {
|
||||
log.Errorf("failed setting daemon permissions: %v", split[1])
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,180 +72,13 @@ func (p *program) Start(svc service.Service) error {
|
||||
p.serverInstanceMu.Unlock()
|
||||
|
||||
log.Printf("started daemon server: %v", split[1])
|
||||
if err := p.serv.Serve(srvListener); err != nil {
|
||||
if err := p.serv.Serve(listen); err != nil {
|
||||
log.Errorf("failed to serve daemon requests: %v", err)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func lookupUser(username string) (int, error) {
|
||||
u, err := shell.LookupWithGetent(username)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("lookup user %s: %w", username, err)
|
||||
}
|
||||
uid, err := strconv.Atoi(u.Uid)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("parse uid %s: %w", u.Uid, err)
|
||||
}
|
||||
return uid, nil
|
||||
}
|
||||
|
||||
// addGroup creates a system group if it doesn't already exist and returns the gid.
|
||||
// Must run as root.
|
||||
func addGroup(name string) (int, error) {
|
||||
group, err := shell.LookupGroupWithGetent(name)
|
||||
if err == nil {
|
||||
gid, err := strconv.ParseInt(group.Gid, 10, 64)
|
||||
return int(gid), err
|
||||
}
|
||||
|
||||
// looup failed, create the group
|
||||
groupadd, err := exec.LookPath("groupadd")
|
||||
if err != nil {
|
||||
// Fallback for Alpine/BusyBox systems.
|
||||
if groupadd, err = exec.LookPath("addgroup"); err != nil {
|
||||
return -1, errors.New("neither groupadd nor addgroup found")
|
||||
}
|
||||
}
|
||||
|
||||
// Use --system for a service/daemon group (no login, low GID).
|
||||
out, err := exec.Command(groupadd, "--system", name).CombinedOutput()
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("create group %q: %w: %s", name, err, out)
|
||||
}
|
||||
if group, err := shell.LookupWithGetent(name); err == nil {
|
||||
gid, err := strconv.ParseInt(group.Gid, 10, 64)
|
||||
return int(gid), err
|
||||
}
|
||||
return -1, fmt.Errorf("lookup group %q: %w", name, err)
|
||||
}
|
||||
|
||||
// restrictSocket locks the unix socket down to the owner uid plus the netbird
|
||||
// group (0660). If the group cannot be created or applied, it fails closed to
|
||||
// owner-only 0600 — it never leaves the socket world-writable.
|
||||
func restrictSocket(path string, uid int) error {
|
||||
// TODO: introduce flag to configure this (LDAP/AD usecase)
|
||||
gid, err := addGroup("netbird")
|
||||
if err != nil {
|
||||
log.Errorf("create netbird group, failing closed to owner-only 0600: %v", err)
|
||||
return chownChmod(path, uid, -1, 0600)
|
||||
}
|
||||
if err := chownChmod(path, uid, gid, 0660); err != nil {
|
||||
log.Errorf("apply netbird group to socket, failing closed to owner-only 0600: %v", err)
|
||||
return chownChmod(path, uid, -1, 0600)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// chownChmod sets ownership and mode on the socket. A gid of -1 leaves the
|
||||
// group unchanged.
|
||||
func chownChmod(path string, uid, gid int, mode os.FileMode) error {
|
||||
if err := os.Chown(path, uid, gid); err != nil {
|
||||
return fmt.Errorf("chown socket %s: %w", path, err)
|
||||
}
|
||||
if err := os.Chmod(path, mode); err != nil {
|
||||
return fmt.Errorf("chmod socket %s: %w", path, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// tofuListener implements trust-on-first-use for the daemon control socket.
|
||||
// The socket starts world-writable; the first caller's uid (read via the
|
||||
// platform peer-credential mechanism) becomes the owner. On that first
|
||||
// connection the socket is restricted (see restrictSocket) and the owner is
|
||||
// persisted so the open window never reopens on later starts. Connections that
|
||||
// raced in during the open window and are neither the owner nor root are
|
||||
// dropped. Changing the socket mode does not disturb the already-open
|
||||
// connection, so the first caller's request is served normally.
|
||||
type tofuListener struct {
|
||||
net.Listener
|
||||
path string
|
||||
mu sync.Mutex
|
||||
owner int // -1 until claimed
|
||||
}
|
||||
|
||||
func (l *tofuListener) Accept() (net.Conn, error) {
|
||||
for {
|
||||
c, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
uid, err := peerUID(c)
|
||||
if err != nil {
|
||||
log.Errorf("read peer credentials, dropping connection: %v", err)
|
||||
_ = c.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
l.mu.Lock()
|
||||
if l.owner == -1 {
|
||||
if err := restrictSocket(l.path, uid); err != nil {
|
||||
l.mu.Unlock()
|
||||
_ = c.Close()
|
||||
// Refuse to serve on a socket we could not lock down.
|
||||
return nil, fmt.Errorf("restrict socket on first connection: %w", err)
|
||||
}
|
||||
l.owner = uid
|
||||
persistSocketOwner(uid)
|
||||
log.Infof("control socket restricted to first caller (uid %d)", uid)
|
||||
l.mu.Unlock()
|
||||
return c, nil
|
||||
}
|
||||
owner := l.owner
|
||||
l.mu.Unlock()
|
||||
|
||||
// New connects are already gated by the 0660 perms set above; this only
|
||||
// drops anything that slipped in during the brief open window.
|
||||
if uid != owner && uid != 0 {
|
||||
log.Warnf("dropping non-owner connection (uid %d) during socket bootstrap", uid)
|
||||
_ = c.Close()
|
||||
continue
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
|
||||
// effectiveSocketOwner returns the configured socket owner: the --socket-owner
|
||||
// flag when set, otherwise the owner persisted by a previous TOFU migration.
|
||||
func effectiveSocketOwner() string {
|
||||
if socketOwner != "" {
|
||||
return socketOwner
|
||||
}
|
||||
params, err := loadServiceParams()
|
||||
if err != nil {
|
||||
log.Errorf("load service params for socket owner: %v", err)
|
||||
return ""
|
||||
}
|
||||
if params != nil {
|
||||
return params.SocketOwner
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// persistSocketOwner records the TOFU-selected owner (by username) so the next
|
||||
// daemon start restricts the socket immediately, with no open window.
|
||||
func persistSocketOwner(uid int) {
|
||||
u, err := user.LookupId(strconv.Itoa(uid))
|
||||
if err != nil {
|
||||
log.Errorf("resolve uid %d to username for persistence: %v", uid, err)
|
||||
return
|
||||
}
|
||||
params, err := loadServiceParams()
|
||||
if err != nil {
|
||||
log.Errorf("load service params to persist socket owner: %v", err)
|
||||
return
|
||||
}
|
||||
if params == nil {
|
||||
params = currentServiceParams()
|
||||
}
|
||||
params.SocketOwner = u.Username
|
||||
if err := saveServiceParams(params); err != nil {
|
||||
log.Errorf("persist socket owner: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *program) Stop(srv service.Service) error {
|
||||
p.serverInstanceMu.Lock()
|
||||
if p.serverInstance != nil {
|
||||
|
||||
@@ -67,14 +67,6 @@ func buildServiceArguments() []string {
|
||||
args = append(args, "--disable-networks")
|
||||
}
|
||||
|
||||
if socketOwner != "" {
|
||||
args = append(args, "--socket-owner", socketOwner)
|
||||
}
|
||||
|
||||
if strictSocketDisabled {
|
||||
args = append(args, "--disable-strict-socket")
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
@@ -135,8 +127,6 @@ var installCmd = &cobra.Command{
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Printf("SUDO_UID: %s\n", os.Getenv("SUDO_UID"))
|
||||
|
||||
if err := loadAndApplyServiceParams(cmd); err != nil {
|
||||
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
|
||||
}
|
||||
|
||||
@@ -30,8 +30,6 @@ type serviceParams struct {
|
||||
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
|
||||
EnableCapture bool `json:"enable_capture,omitempty"`
|
||||
DisableNetworks bool `json:"disable_networks,omitempty"`
|
||||
SocketOwner string `json:"socket_owner,omitempty"`
|
||||
DisableStrictSocket bool `json:"disable_strict_socket,omitempty"`
|
||||
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
|
||||
}
|
||||
|
||||
@@ -84,8 +82,6 @@ func currentServiceParams() *serviceParams {
|
||||
DisableUpdateSettings: updateSettingsDisabled,
|
||||
EnableCapture: captureEnabled,
|
||||
DisableNetworks: networksDisabled,
|
||||
SocketOwner: socketOwner,
|
||||
DisableStrictSocket: strictSocketDisabled,
|
||||
}
|
||||
|
||||
if len(serviceEnvVars) > 0 {
|
||||
@@ -158,14 +154,6 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
|
||||
networksDisabled = params.DisableNetworks
|
||||
}
|
||||
|
||||
if !serviceCmd.PersistentFlags().Changed("socket-owner") {
|
||||
socketOwner = params.SocketOwner
|
||||
}
|
||||
|
||||
if !serviceCmd.PersistentFlags().Changed("disable-strict-socket") {
|
||||
strictSocketDisabled = params.DisableStrictSocket
|
||||
}
|
||||
|
||||
applyServiceEnvParams(cmd, params)
|
||||
}
|
||||
|
||||
|
||||
196
client/cmd/service_privileged_test.go
Normal file
196
client/cmd/service_privileged_test.go
Normal file
@@ -0,0 +1,196 @@
|
||||
//go:build privileged
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
serviceStartTimeout = 10 * time.Second
|
||||
serviceStopTimeout = 5 * time.Second
|
||||
statusPollInterval = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
// waitForServiceStatus waits for service to reach expected status with timeout
|
||||
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer timeoutCancel()
|
||||
|
||||
ticker := time.NewTicker(statusPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
|
||||
case <-ticker.C:
|
||||
status, err := s.Status()
|
||||
if err != nil {
|
||||
// Continue polling on transient errors
|
||||
continue
|
||||
}
|
||||
if status == expectedStatus {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestServiceLifecycle tests the complete service lifecycle
|
||||
func TestServiceLifecycle(t *testing.T) {
|
||||
// TODO: Add support for Windows and macOS
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
if os.Getenv("CONTAINER") == "true" {
|
||||
t.Skip("Skipping service lifecycle test in container environment")
|
||||
}
|
||||
|
||||
originalServiceName := serviceName
|
||||
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
|
||||
defer func() {
|
||||
serviceName = originalServiceName
|
||||
}()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
|
||||
logLevel = "info"
|
||||
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
|
||||
|
||||
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
|
||||
t.Cleanup(func() {
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service config: %v", err)
|
||||
return
|
||||
}
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// If the subtests already cleaned up, there's nothing to do.
|
||||
if _, err := s.Status(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.Stop(); err != nil {
|
||||
t.Errorf("cleanup: stop service: %v", err)
|
||||
}
|
||||
if err := s.Uninstall(); err != nil {
|
||||
t.Errorf("cleanup: uninstall service: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Install", func(t *testing.T) {
|
||||
installCmd.SetContext(ctx)
|
||||
err := installCmd.RunE(installCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := newSVCConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := s.Status()
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, service.StatusUnknown, status)
|
||||
})
|
||||
|
||||
t.Run("Start", func(t *testing.T) {
|
||||
startCmd.SetContext(ctx)
|
||||
err := startCmd.RunE(startCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Restart", func(t *testing.T) {
|
||||
restartCmd.SetContext(ctx)
|
||||
err := restartCmd.RunE(restartCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Reconfigure", func(t *testing.T) {
|
||||
originalLogLevel := logLevel
|
||||
logLevel = "debug"
|
||||
defer func() {
|
||||
logLevel = originalLogLevel
|
||||
}()
|
||||
|
||||
reconfigureCmd.SetContext(ctx)
|
||||
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Stop", func(t *testing.T) {
|
||||
stopCmd.SetContext(ctx)
|
||||
err := stopCmd.RunE(stopCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, stopped)
|
||||
})
|
||||
|
||||
t.Run("Uninstall", func(t *testing.T) {
|
||||
uninstallCmd.SetContext(ctx)
|
||||
err := uninstallCmd.RunE(uninstallCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := newSVCConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = s.Status()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
@@ -1,16 +1,12 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -31,186 +27,6 @@ func TestMain(m *testing.M) {
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
const (
|
||||
serviceStartTimeout = 10 * time.Second
|
||||
serviceStopTimeout = 5 * time.Second
|
||||
statusPollInterval = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
// waitForServiceStatus waits for service to reach expected status with timeout
|
||||
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer timeoutCancel()
|
||||
|
||||
ticker := time.NewTicker(statusPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
|
||||
case <-ticker.C:
|
||||
status, err := s.Status()
|
||||
if err != nil {
|
||||
// Continue polling on transient errors
|
||||
continue
|
||||
}
|
||||
if status == expectedStatus {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestServiceLifecycle tests the complete service lifecycle
|
||||
func TestServiceLifecycle(t *testing.T) {
|
||||
// TODO: Add support for Windows and macOS
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
if os.Getenv("CONTAINER") == "true" {
|
||||
t.Skip("Skipping service lifecycle test in container environment")
|
||||
}
|
||||
|
||||
originalServiceName := serviceName
|
||||
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
|
||||
defer func() {
|
||||
serviceName = originalServiceName
|
||||
}()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
|
||||
logLevel = "info"
|
||||
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
|
||||
|
||||
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
|
||||
t.Cleanup(func() {
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service config: %v", err)
|
||||
return
|
||||
}
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// If the subtests already cleaned up, there's nothing to do.
|
||||
if _, err := s.Status(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.Stop(); err != nil {
|
||||
t.Errorf("cleanup: stop service: %v", err)
|
||||
}
|
||||
if err := s.Uninstall(); err != nil {
|
||||
t.Errorf("cleanup: uninstall service: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Install", func(t *testing.T) {
|
||||
installCmd.SetContext(ctx)
|
||||
err := installCmd.RunE(installCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := newSVCConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := s.Status()
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, service.StatusUnknown, status)
|
||||
})
|
||||
|
||||
t.Run("Start", func(t *testing.T) {
|
||||
startCmd.SetContext(ctx)
|
||||
err := startCmd.RunE(startCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Restart", func(t *testing.T) {
|
||||
restartCmd.SetContext(ctx)
|
||||
err := restartCmd.RunE(restartCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Reconfigure", func(t *testing.T) {
|
||||
originalLogLevel := logLevel
|
||||
logLevel = "debug"
|
||||
defer func() {
|
||||
logLevel = originalLogLevel
|
||||
}()
|
||||
|
||||
reconfigureCmd.SetContext(ctx)
|
||||
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Stop", func(t *testing.T) {
|
||||
stopCmd.SetContext(ctx)
|
||||
err := stopCmd.RunE(stopCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, stopped)
|
||||
})
|
||||
|
||||
t.Run("Uninstall", func(t *testing.T) {
|
||||
uninstallCmd.SetContext(ctx)
|
||||
err := uninstallCmd.RunE(uninstallCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := newSVCConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = s.Status()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestServiceEnvVars tests environment variable parsing
|
||||
func TestServiceEnvVars(t *testing.T) {
|
||||
tests := []struct {
|
||||
|
||||
@@ -279,6 +279,10 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
// Cancel the client context before stopping: Engine.Start blocks on the
|
||||
// signal stream while holding the engine mutex and only unblocks on
|
||||
// cancellation. Stopping first would deadlock on that mutex.
|
||||
cancel()
|
||||
if stopErr := client.Stop(); stopErr != nil {
|
||||
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
|
||||
}
|
||||
@@ -442,8 +446,8 @@ func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession,
|
||||
|
||||
// IdentityForIP looks up a remote peer by its tunnel IP using the
|
||||
// embedded client's status recorder. Returns the peer's WireGuard public
|
||||
// key and FQDN. ok=false means the IP isn't in this client's peer
|
||||
// roster — callers should treat that as "unknown peer".
|
||||
// key and FQDN. ok=false means the IP doesn't belong to an active peer
|
||||
// — offline roster peers are treated as unknown, same as foreign IPs.
|
||||
func (c *Client) IdentityForIP(ip netip.Addr) (pubKey, fqdn string, ok bool) {
|
||||
if !ip.IsValid() || c.recorder == nil {
|
||||
return "", "", false
|
||||
|
||||
168
client/embed/embed_test.go
Normal file
168
client/embed/embed_test.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
mgmt "github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const testSetupKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
|
||||
// TestClientStartTimeoutRollback reproduces a deadlock between Engine.Start and
|
||||
// Engine.Stop. The signal endpoint accepts gRPC connections but never serves the
|
||||
// SignalExchange service, so Engine.Start parks in WaitStreamConnected while
|
||||
// holding the engine mutex. When the Start context expires, the rollback path
|
||||
// calls ConnectClient.Stop, which must not block forever acquiring that mutex.
|
||||
func TestClientStartTimeoutRollback(t *testing.T) {
|
||||
signalAddr := startBlackholeSignal(t)
|
||||
mgmAddr := startManagement(t, signalAddr)
|
||||
|
||||
wgPort := 0
|
||||
client, err := New(Options{
|
||||
DeviceName: "embed-rollback-test",
|
||||
SetupKey: testSetupKey,
|
||||
ManagementURL: "http://" + mgmAddr,
|
||||
WireguardPort: &wgPort,
|
||||
})
|
||||
require.NoError(t, err, "embed client creation must succeed")
|
||||
|
||||
startCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
startErr := make(chan error, 1)
|
||||
go func() {
|
||||
startErr <- client.Start(startCtx)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-startErr:
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
case <-time.After(60 * time.Second):
|
||||
t.Fatal("client.Start did not return after its context expired: Engine.Stop deadlocked against Engine.Start waiting for the signal stream")
|
||||
}
|
||||
}
|
||||
|
||||
// startBlackholeSignal starts a gRPC server without the SignalExchange service
|
||||
// registered. Connections succeed, but the signal stream can never be
|
||||
// established, which keeps Engine.Start parked in WaitStreamConnected.
|
||||
func startBlackholeSignal(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
s := grpc.NewServer()
|
||||
go func() {
|
||||
if err := s.Serve(lis); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
t.Cleanup(s.Stop)
|
||||
|
||||
return lis.Addr().String()
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, signalAddr string) string {
|
||||
t.Helper()
|
||||
|
||||
cfg := &config.Config{
|
||||
Stuns: []*config.Host{},
|
||||
TURNConfig: &config.TURNConfig{},
|
||||
Relay: &config.Relay{
|
||||
Addresses: []string{"127.0.0.1:1234"},
|
||||
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||
Secret: "222222222222222222",
|
||||
},
|
||||
Signal: &config.Host{
|
||||
Proto: "http",
|
||||
URI: signalAddr,
|
||||
},
|
||||
Datadir: t.TempDir(),
|
||||
HttpConfig: nil,
|
||||
}
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
s := grpc.NewServer()
|
||||
|
||||
testStore, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", cfg.Datadir)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
permissionsManager := permissions.NewManager(testStore)
|
||||
peersManager := peers.NewManager(testStore, permissionsManager)
|
||||
jobManager := job.NewJobManager(nil, testStore, peersManager)
|
||||
|
||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
iv, err := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||
require.NoError(t, err)
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
settingsMockManager.EXPECT().
|
||||
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
settingsMockManager.EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := mgmt.NewAccountRequestBuffer(context.Background(), testStore)
|
||||
networkMapController := controller.NewController(context.Background(), testStore, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(testStore, peersManager), cfg)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), cfg, testStore, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
require.NoError(t, err)
|
||||
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, cfg.TURNConfig, cfg.Relay, settingsMockManager, groupsManager)
|
||||
require.NoError(t, err)
|
||||
|
||||
mgmtServer, err := nbgrpc.NewServer(cfg, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
require.NoError(t, err)
|
||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||
|
||||
go func() {
|
||||
if err := s.Serve(lis); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
t.Cleanup(s.Stop)
|
||||
|
||||
return lis.Addr().String()
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package iptables
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net"
|
||||
"slices"
|
||||
|
||||
@@ -421,12 +422,17 @@ func (m *aclManager) updateState() {
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
// Clone the maps so the persisted state holds a private snapshot. The
|
||||
// live maps keep being mutated by subsequent rule operations while the
|
||||
// state manager marshals the state from its periodic-save goroutine.
|
||||
// Sharing them by reference races the two and aborts the process with a
|
||||
// concurrent map iteration and write.
|
||||
if m.v6 {
|
||||
currentState.ACLEntries6 = m.entries
|
||||
currentState.ACLIPsetStore6 = m.ipsetStore
|
||||
currentState.ACLEntries6 = maps.Clone(m.entries)
|
||||
currentState.ACLIPsetStore6 = m.ipsetStore.clone()
|
||||
} else {
|
||||
currentState.ACLEntries = m.entries
|
||||
currentState.ACLIPsetStore = m.ipsetStore
|
||||
currentState.ACLEntries = maps.Clone(m.entries)
|
||||
currentState.ACLIPsetStore = m.ipsetStore.clone()
|
||||
}
|
||||
|
||||
if err := m.stateManager.UpdateState(currentState); err != nil {
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build privileged
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
|
||||
@@ -4,6 +4,7 @@ package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -749,11 +750,17 @@ func (r *router) updateState() {
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
// Clone the rule map so the persisted state holds a private snapshot. The
|
||||
// live map keeps being mutated by subsequent rule operations while the
|
||||
// state manager marshals the state from its periodic-save goroutine.
|
||||
// Sharing it by reference races the two and aborts the process with a
|
||||
// concurrent map iteration and write. The ipset counter guards itself
|
||||
// during marshaling, so it can be shared directly.
|
||||
if r.v6 {
|
||||
currentState.RouteRules6 = r.rules
|
||||
currentState.RouteRules6 = maps.Clone(r.rules)
|
||||
currentState.RouteIPsetCounter6 = r.ipsetCounter
|
||||
} else {
|
||||
currentState.RouteRules = r.rules
|
||||
currentState.RouteRules = maps.Clone(r.rules)
|
||||
currentState.RouteIPsetCounter = r.ipsetCounter
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build !android && privileged
|
||||
|
||||
package iptables
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package iptables
|
||||
|
||||
import "encoding/json"
|
||||
import (
|
||||
"encoding/json"
|
||||
"maps"
|
||||
)
|
||||
|
||||
type ipList struct {
|
||||
ips map[string]struct{}
|
||||
@@ -19,6 +22,14 @@ func (s *ipList) addIP(ip string) {
|
||||
s.ips[ip] = struct{}{}
|
||||
}
|
||||
|
||||
// clone returns a deep copy of the ipList with its own ips map.
|
||||
func (s *ipList) clone() *ipList {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return &ipList{ips: maps.Clone(s.ips)}
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler
|
||||
func (s *ipList) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
@@ -55,6 +66,19 @@ func newIpsetStore() *ipsetStore {
|
||||
}
|
||||
}
|
||||
|
||||
// clone returns a deep copy of the ipsetStore with its own ipsets map and
|
||||
// independent ipList entries.
|
||||
func (s *ipsetStore) clone() *ipsetStore {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := &ipsetStore{ipsets: make(map[string]*ipList, len(s.ipsets))}
|
||||
for name, list := range s.ipsets {
|
||||
cloned.ipsets[name] = list.clone()
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
|
||||
r, ok := s.ipsets[ipsetName]
|
||||
return r, ok
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build privileged
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build !android && privileged
|
||||
|
||||
package nftables
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build privileged
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build linux && !android
|
||||
//go:build linux && !android && privileged
|
||||
|
||||
package wgproxy
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !linux
|
||||
//go:build !linux || !privileged
|
||||
|
||||
package wgproxy
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build linux && !android
|
||||
//go:build linux && !android && privileged
|
||||
|
||||
package wgproxy
|
||||
|
||||
@@ -26,64 +26,6 @@ func compareUDPAddr(addr1, addr2 net.Addr) bool {
|
||||
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
|
||||
}
|
||||
|
||||
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
|
||||
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
|
||||
wgPort := 51850
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("192.168.0.56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
|
||||
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
|
||||
wgPort := 51851
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("fe80::56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses
|
||||
func TestRedirectAs_UDP_IPv4(t *testing.T) {
|
||||
wgPort := 51852
|
||||
@@ -256,6 +198,64 @@ func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
|
||||
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
|
||||
wgPort := 51850
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("192.168.0.56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
|
||||
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
|
||||
wgPort := 51851
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("fe80::56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints
|
||||
func TestRedirectAs_Multiple_Switches(t *testing.T) {
|
||||
wgPort := 51856
|
||||
|
||||
@@ -516,6 +516,14 @@ func (g *BundleGenerator) addConfig() error {
|
||||
}
|
||||
}
|
||||
|
||||
// Surface the set of MDM-enforced keys so a support engineer reading
|
||||
// the bundle can tell which field values are user-set vs MDM-overridden.
|
||||
// Same semantics as the mDMManagedFields list returned by the
|
||||
// GetConfig RPC consumed by `netbird debug config`.
|
||||
if managed := g.internalConfig.Policy().ManagedKeys(); len(managed) > 0 {
|
||||
configContent.WriteString(fmt.Sprintf("MDMManagedFields: %v\n", managed))
|
||||
}
|
||||
|
||||
configReader := strings.NewReader(configContent.String())
|
||||
if err := g.addFileToZip(configReader, "config.txt"); err != nil {
|
||||
return fmt.Errorf("add config file to zip: %w", err)
|
||||
|
||||
@@ -843,6 +843,7 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
|
||||
"PreSharedKey": "sensitive: WireGuard pre-shared key",
|
||||
"SSHKey": "sensitive: SSH private key",
|
||||
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
|
||||
"policy": "non-config: in-memory MDM policy snapshot, surfaced via Config.Policy() / GetConfigResponse.MDMManagedFields",
|
||||
}
|
||||
|
||||
mURL, _ := url.Parse("https://api.example.com:443")
|
||||
|
||||
@@ -482,7 +482,7 @@ func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16,
|
||||
// completely when every proxy peer is offline (the upstream may still
|
||||
// be reachable some other way, or the peerstore may be stale).
|
||||
func (d *Resolver) filterDisconnectedPeerAnswers(logger *log.Entry, question dns.Question, records []dns.RR) []dns.RR {
|
||||
if len(records) == 0 {
|
||||
if len(records) < 2 {
|
||||
return records
|
||||
}
|
||||
d.mu.RLock()
|
||||
|
||||
@@ -2738,6 +2738,17 @@ func TestLocalResolver_FilterDisconnectedPeerAnswers(t *testing.T) {
|
||||
connByIP: nil,
|
||||
wantInOrder: []string{"100.64.0.10", "100.64.0.11"},
|
||||
},
|
||||
{
|
||||
// A single answer is never filtered: dropping it would only
|
||||
// trigger the empty-answer escape hatch, so the fast path
|
||||
// returns it untouched.
|
||||
name: "single disconnected answer passes through",
|
||||
records: []nbdns.SimpleRecord{disconnectedRec},
|
||||
connByIP: map[string]ipState{
|
||||
"100.64.0.11": {known: true, connected: false},
|
||||
},
|
||||
wantInOrder: []string{"100.64.0.11"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
|
||||
@@ -14,6 +14,10 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// errNoSuitableAddress mirrors the unexported error string the net package
|
||||
// uses when a resolved host has no addresses of the requested family.
|
||||
const errNoSuitableAddress = "no suitable address found"
|
||||
|
||||
// GenerateRequestID creates a random 8-character hex string for request tracing.
|
||||
func GenerateRequestID() string {
|
||||
bytes := make([]byte, 4)
|
||||
@@ -126,6 +130,14 @@ func LookupIP(ctx context.Context, r resolver, network, host string, qtype uint1
|
||||
}
|
||||
|
||||
func getRcodeForError(ctx context.Context, r resolver, host string, qtype uint16, err error) int {
|
||||
// The net package returns this AddrError when the host resolves but has
|
||||
// no addresses of the requested family. The domain exists, so answer
|
||||
// NODATA instead of SERVFAIL.
|
||||
var addrErr *net.AddrError
|
||||
if errors.As(err, &addrErr) && addrErr.Err == errNoSuitableAddress {
|
||||
return dns.RcodeSuccess
|
||||
}
|
||||
|
||||
var dnsErr *net.DNSError
|
||||
if !errors.As(err, &dnsErr) {
|
||||
return dns.RcodeServerFailure
|
||||
|
||||
122
client/internal/dns/resutil/resolve_test.go
Normal file
122
client/internal/dns/resutil/resolve_test.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package resutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockResolver struct {
|
||||
// results maps network ("ip4"/"ip6") to the lookup outcome.
|
||||
results map[string]mockLookup
|
||||
}
|
||||
|
||||
type mockLookup struct {
|
||||
ips []netip.Addr
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockResolver) LookupNetIP(_ context.Context, network, _ string) ([]netip.Addr, error) {
|
||||
res, ok := m.results[network]
|
||||
if !ok {
|
||||
return nil, errors.New("unexpected network: " + network)
|
||||
}
|
||||
return res.ips, res.err
|
||||
}
|
||||
|
||||
func TestLookupIP_Success(t *testing.T) {
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip4": {ips: []netip.Addr{netip.MustParseAddr("::ffff:192.0.2.1")}},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip4", "example.com.", dns.TypeA)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, result.Rcode, "successful lookup should return NOERROR")
|
||||
require.Len(t, result.IPs, 1, "should return the resolved address")
|
||||
assert.Equal(t, netip.MustParseAddr("192.0.2.1"), result.IPs[0], "v4-mapped address should be unmapped")
|
||||
}
|
||||
|
||||
func TestLookupIP_NoSuitableAddress(t *testing.T) {
|
||||
// The net package returns this AddrError when the host resolves but has
|
||||
// no addresses of the requested family (e.g. AAAA query for a v4-only
|
||||
// hosts file entry). The domain exists, so this is NODATA, not SERVFAIL.
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip6": {err: &net.AddrError{Err: "no suitable address found", Addr: "example.com."}},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip6", "example.com.", dns.TypeAAAA)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, result.Rcode, "no suitable address should map to NODATA")
|
||||
assert.Empty(t, result.IPs, "NODATA response should carry no addresses")
|
||||
}
|
||||
|
||||
// TestErrNoSuitableAddressMatchesNetPackage pins our copy of the error string
|
||||
// to what the net package actually emits. A literal IP of the wrong family
|
||||
// takes the same filterAddrList path as a resolved hostname, without network
|
||||
// access.
|
||||
func TestErrNoSuitableAddressMatchesNetPackage(t *testing.T) {
|
||||
_, err := (&net.Resolver{}).LookupNetIP(context.Background(), "ip6", "192.0.2.1")
|
||||
require.Error(t, err)
|
||||
|
||||
var addrErr *net.AddrError
|
||||
require.ErrorAs(t, err, &addrErr, "wrong-family lookup should return AddrError")
|
||||
assert.Equal(t, errNoSuitableAddress, addrErr.Err, "net package error string should match our constant")
|
||||
}
|
||||
|
||||
func TestLookupIP_OtherAddrError(t *testing.T) {
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip4": {err: &net.AddrError{Err: "some other address problem", Addr: "example.com."}},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip4", "example.com.", dns.TypeA)
|
||||
|
||||
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "unrecognized AddrError should map to SERVFAIL")
|
||||
}
|
||||
|
||||
func TestLookupIP_NotFoundNXDomain(t *testing.T) {
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip4": {err: &net.DNSError{Err: "no such host", Name: "example.com.", IsNotFound: true}},
|
||||
"ip6": {err: &net.DNSError{Err: "no such host", Name: "example.com.", IsNotFound: true}},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip4", "example.com.", dns.TypeA)
|
||||
|
||||
assert.Equal(t, dns.RcodeNameError, result.Rcode, "not found for both families should map to NXDOMAIN")
|
||||
}
|
||||
|
||||
func TestLookupIP_NotFoundNoData(t *testing.T) {
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip6": {err: &net.DNSError{Err: "no such host", Name: "example.com.", IsNotFound: true}},
|
||||
"ip4": {ips: []netip.Addr{netip.MustParseAddr("192.0.2.1")}},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip6", "example.com.", dns.TypeAAAA)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, result.Rcode, "not found with the other family present should map to NODATA")
|
||||
}
|
||||
|
||||
func TestLookupIP_GenericError(t *testing.T) {
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip4": {err: errors.New("connection refused")},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip4", "example.com.", dns.TypeA)
|
||||
|
||||
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "generic error should map to SERVFAIL")
|
||||
}
|
||||
|
||||
func TestLookupIP_DNSErrorNotIsNotFound(t *testing.T) {
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip4": {err: &net.DNSError{Err: "server misbehaving", Name: "example.com.", IsTemporary: true}},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip4", "example.com.", dns.TypeA)
|
||||
|
||||
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
|
||||
}
|
||||
@@ -777,13 +777,24 @@ func (s *DefaultServer) applyHostConfig() {
|
||||
// context is released rather than leaked until GC.
|
||||
func (s *DefaultServer) registerFallback() {
|
||||
originalNameservers := s.hostManager.getOriginalNameservers()
|
||||
if len(originalNameservers) == 0 {
|
||||
|
||||
serverIP := s.service.RuntimeIP()
|
||||
var servers []netip.AddrPort
|
||||
for _, ns := range originalNameservers {
|
||||
if ns == serverIP {
|
||||
log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, serverIP)
|
||||
continue
|
||||
}
|
||||
servers = append(servers, netip.AddrPortFrom(ns, DefaultPort))
|
||||
}
|
||||
|
||||
if len(servers) == 0 {
|
||||
log.Debugf("no fallback upstreams to register; clearing PriorityFallback handler")
|
||||
s.clearFallback()
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("registering original nameservers %v as upstream handlers with priority %d", originalNameservers, PriorityFallback)
|
||||
log.Infof("registering original nameservers %v as upstream handlers with priority %d", servers, PriorityFallback)
|
||||
|
||||
handler, err := newUpstreamResolver(
|
||||
s.ctx,
|
||||
@@ -797,11 +808,6 @@ func (s *DefaultServer) registerFallback() {
|
||||
return
|
||||
}
|
||||
handler.selectedRoutes = s.selectedRoutes
|
||||
|
||||
var servers []netip.AddrPort
|
||||
for _, ns := range originalNameservers {
|
||||
servers = append(servers, netip.AddrPortFrom(ns, DefaultPort))
|
||||
}
|
||||
handler.addRace(servers)
|
||||
|
||||
prev := s.fallbackHandler
|
||||
|
||||
501
client/internal/dns/server_privileged_test.go
Normal file
501
client/internal/dns/server_privileged_test.go
Normal file
@@ -0,0 +1,501 @@
|
||||
//go:build privileged
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||
var srvs []netip.AddrPort
|
||||
for _, srv := range servers {
|
||||
srvs = append(srvs, srv.AddrPort())
|
||||
}
|
||||
u := &upstreamResolverBase{
|
||||
domain: domain.Domain(d),
|
||||
cancel: func() {},
|
||||
}
|
||||
u.addRace(srvs)
|
||||
return u
|
||||
}
|
||||
|
||||
func TestUpdateDNSServer(t *testing.T) {
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.4.4"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
}
|
||||
|
||||
dummyHandler := local.NewResolver()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
initUpstreamMap registeredHandlerMap
|
||||
initLocalZones []nbdns.CustomZone
|
||||
initSerial uint64
|
||||
inputSerial uint64
|
||||
inputUpdate nbdns.Config
|
||||
shouldFail bool
|
||||
expectedUpstreamMap registeredHandlerMap
|
||||
expectedLocalQs []dns.Question
|
||||
}{
|
||||
{
|
||||
name: "Initial Config Should Succeed",
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.io",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
dummyHandler.ID(): handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityLocal,
|
||||
},
|
||||
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||
domain: nbdns.RootZone,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
},
|
||||
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
|
||||
},
|
||||
{
|
||||
name: "New Config Should Succeed",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.io",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
"local-resolver": handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityLocal,
|
||||
},
|
||||
},
|
||||
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
||||
},
|
||||
{
|
||||
name: "Smaller Config Serial Should Be Skipped",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 2,
|
||||
inputSerial: 1,
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid NS Group Nameservers list Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Custom Zone Records list Should Skip",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||
domain: ".",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityDefault,
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: true},
|
||||
expectedUpstreamMap: make(registeredHandlerMap),
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
{
|
||||
name: "Disabled Service Should clean map",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: false},
|
||||
expectedUpstreamMap: make(registeredHandlerMap),
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
}
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
privKey, _ := wgtypes.GenerateKey()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: fmt.Sprintf("utun230%d", n),
|
||||
Address: wgaddr.MustParseWGAddress(fmt.Sprintf("100.66.100.%d/32", n+1)),
|
||||
WGPort: 33100,
|
||||
WGPrivKey: privKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
|
||||
wgIface, err := iface.NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = wgIface.Create()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
err = wgIface.Close()
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
||||
WgInterface: wgIface,
|
||||
CustomAddress: "",
|
||||
StatusRecorder: peer.NewRecorder("mgm"),
|
||||
StateManager: nil,
|
||||
DisableSys: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
err = dnsServer.hostManager.restoreHostDNS()
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
||||
dnsServer.localResolver.Update(testCase.initLocalZones)
|
||||
dnsServer.updateSerial = testCase.initSerial
|
||||
|
||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||
if err != nil {
|
||||
if testCase.shouldFail {
|
||||
return
|
||||
}
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
}
|
||||
|
||||
if len(dnsServer.dnsMuxMap) != len(testCase.expectedUpstreamMap) {
|
||||
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxMap))
|
||||
}
|
||||
|
||||
for key := range testCase.expectedUpstreamMap {
|
||||
_, found := dnsServer.dnsMuxMap[key]
|
||||
if !found {
|
||||
t.Fatalf("update upstream failed, key %s was not found in the dnsMuxMap: %#v", key, dnsServer.dnsMuxMap)
|
||||
}
|
||||
}
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
for _, q := range testCase.expectedLocalQs {
|
||||
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
|
||||
Question: []dns.Question{q},
|
||||
})
|
||||
}
|
||||
|
||||
if len(testCase.expectedLocalQs) > 0 {
|
||||
assert.NotNil(t, responseMSG, "response message should not be nil")
|
||||
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
|
||||
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
||||
if err != nil {
|
||||
t.Errorf("create stdnet: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: "utun2301",
|
||||
Address: wgaddr.MustParseWGAddress("100.66.100.1/32"),
|
||||
WGPort: 33100,
|
||||
WGPrivKey: privKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
wgIface, err := iface.NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Errorf("build interface wireguard: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = wgIface.Create()
|
||||
if err != nil {
|
||||
t.Errorf("create and init wireguard interface: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err = wgIface.Close(); err != nil {
|
||||
t.Logf("close wireguard interface: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
|
||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||
t.Errorf("set packet filter: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
||||
WgInterface: wgIface,
|
||||
CustomAddress: "",
|
||||
StatusRecorder: peer.NewRecorder("mgm"),
|
||||
StateManager: nil,
|
||||
DisableSys: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("create DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("run DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
|
||||
t.Logf("restore DNS settings on the host: %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxMap = registeredHandlerMap{
|
||||
"id1": handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &local.Resolver{},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
}
|
||||
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
|
||||
dnsServer.updateSerial = 0
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.4.4"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
}
|
||||
|
||||
update := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Start the server with regular configuration
|
||||
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
update2 := update
|
||||
update2.ServiceEnable = false
|
||||
// Disable the server, stop the listener
|
||||
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
update3 := update2
|
||||
update3.NameServerGroups = update3.NameServerGroups[:1]
|
||||
// But service still get updates and we checking that we handle
|
||||
// internal state in the right way
|
||||
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -23,7 +22,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
@@ -104,481 +102,6 @@ func init() {
|
||||
formatter.SetTextFormatter(log.StandardLogger())
|
||||
}
|
||||
|
||||
func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||
var srvs []netip.AddrPort
|
||||
for _, srv := range servers {
|
||||
srvs = append(srvs, srv.AddrPort())
|
||||
}
|
||||
u := &upstreamResolverBase{
|
||||
domain: domain.Domain(d),
|
||||
cancel: func() {},
|
||||
}
|
||||
u.addRace(srvs)
|
||||
return u
|
||||
}
|
||||
|
||||
func TestUpdateDNSServer(t *testing.T) {
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.4.4"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
}
|
||||
|
||||
dummyHandler := local.NewResolver()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
initUpstreamMap registeredHandlerMap
|
||||
initLocalZones []nbdns.CustomZone
|
||||
initSerial uint64
|
||||
inputSerial uint64
|
||||
inputUpdate nbdns.Config
|
||||
shouldFail bool
|
||||
expectedUpstreamMap registeredHandlerMap
|
||||
expectedLocalQs []dns.Question
|
||||
}{
|
||||
{
|
||||
name: "Initial Config Should Succeed",
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.io",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
dummyHandler.ID(): handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityLocal,
|
||||
},
|
||||
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||
domain: nbdns.RootZone,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
},
|
||||
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
|
||||
},
|
||||
{
|
||||
name: "New Config Should Succeed",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.io",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
"local-resolver": handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityLocal,
|
||||
},
|
||||
},
|
||||
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
||||
},
|
||||
{
|
||||
name: "Smaller Config Serial Should Be Skipped",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 2,
|
||||
inputSerial: 1,
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid NS Group Nameservers list Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Custom Zone Records list Should Skip",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||
domain: ".",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityDefault,
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: true},
|
||||
expectedUpstreamMap: make(registeredHandlerMap),
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
{
|
||||
name: "Disabled Service Should clean map",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: false},
|
||||
expectedUpstreamMap: make(registeredHandlerMap),
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
}
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
privKey, _ := wgtypes.GenerateKey()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: fmt.Sprintf("utun230%d", n),
|
||||
Address: wgaddr.MustParseWGAddress(fmt.Sprintf("100.66.100.%d/32", n+1)),
|
||||
WGPort: 33100,
|
||||
WGPrivKey: privKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
|
||||
wgIface, err := iface.NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = wgIface.Create()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
err = wgIface.Close()
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
||||
WgInterface: wgIface,
|
||||
CustomAddress: "",
|
||||
StatusRecorder: peer.NewRecorder("mgm"),
|
||||
StateManager: nil,
|
||||
DisableSys: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
err = dnsServer.hostManager.restoreHostDNS()
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
||||
dnsServer.localResolver.Update(testCase.initLocalZones)
|
||||
dnsServer.updateSerial = testCase.initSerial
|
||||
|
||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||
if err != nil {
|
||||
if testCase.shouldFail {
|
||||
return
|
||||
}
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
}
|
||||
|
||||
if len(dnsServer.dnsMuxMap) != len(testCase.expectedUpstreamMap) {
|
||||
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxMap))
|
||||
}
|
||||
|
||||
for key := range testCase.expectedUpstreamMap {
|
||||
_, found := dnsServer.dnsMuxMap[key]
|
||||
if !found {
|
||||
t.Fatalf("update upstream failed, key %s was not found in the dnsMuxMap: %#v", key, dnsServer.dnsMuxMap)
|
||||
}
|
||||
}
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
for _, q := range testCase.expectedLocalQs {
|
||||
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
|
||||
Question: []dns.Question{q},
|
||||
})
|
||||
}
|
||||
|
||||
if len(testCase.expectedLocalQs) > 0 {
|
||||
assert.NotNil(t, responseMSG, "response message should not be nil")
|
||||
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
|
||||
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
||||
if err != nil {
|
||||
t.Errorf("create stdnet: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: "utun2301",
|
||||
Address: wgaddr.MustParseWGAddress("100.66.100.1/32"),
|
||||
WGPort: 33100,
|
||||
WGPrivKey: privKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
wgIface, err := iface.NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Errorf("build interface wireguard: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = wgIface.Create()
|
||||
if err != nil {
|
||||
t.Errorf("create and init wireguard interface: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err = wgIface.Close(); err != nil {
|
||||
t.Logf("close wireguard interface: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
|
||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||
t.Errorf("set packet filter: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
||||
WgInterface: wgIface,
|
||||
CustomAddress: "",
|
||||
StatusRecorder: peer.NewRecorder("mgm"),
|
||||
StateManager: nil,
|
||||
DisableSys: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("create DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("run DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
|
||||
t.Logf("restore DNS settings on the host: %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxMap = registeredHandlerMap{
|
||||
"id1": handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &local.Resolver{},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
}
|
||||
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
|
||||
dnsServer.updateSerial = 0
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.4.4"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
}
|
||||
|
||||
update := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Start the server with regular configuration
|
||||
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
update2 := update
|
||||
update2.ServiceEnable = false
|
||||
// Disable the server, stop the listener
|
||||
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
update3 := update2
|
||||
update3.NameServerGroups = update3.NameServerGroups[:1]
|
||||
// But service still get updates and we checking that we handle
|
||||
// internal state in the right way
|
||||
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSServerStartStop(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
|
||||
@@ -880,62 +880,25 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
|
||||
}
|
||||
|
||||
if update.GetNetbirdConfig() != nil {
|
||||
wCfg := update.GetNetbirdConfig()
|
||||
err := e.updateTURNs(wCfg.GetTurns())
|
||||
if err != nil {
|
||||
return fmt.Errorf("update TURNs: %w", err)
|
||||
}
|
||||
if err := e.updateNetbirdConfig(update.GetNetbirdConfig()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = e.updateSTUNs(wCfg.GetStuns())
|
||||
if err != nil {
|
||||
return fmt.Errorf("update STUNs: %w", err)
|
||||
}
|
||||
|
||||
var stunTurn []*stun.URI
|
||||
stunTurn = append(stunTurn, e.STUNs...)
|
||||
stunTurn = append(stunTurn, e.TURNs...)
|
||||
e.stunTurn.Store(stunTurn)
|
||||
|
||||
err = e.handleRelayUpdate(wCfg.GetRelay())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = e.handleFlowUpdate(wCfg.GetFlow())
|
||||
if err != nil {
|
||||
return fmt.Errorf("handle the flow configuration: %w", err)
|
||||
}
|
||||
|
||||
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
|
||||
log.Warnf("Failed to update DNS server config: %v", err)
|
||||
}
|
||||
|
||||
// todo update signal
|
||||
// Posture checks are bound to the network map presence:
|
||||
// NetworkMap != nil, checks present -> apply the received checks
|
||||
// NetworkMap != nil, checks nil -> posture checks were removed, clear them
|
||||
// NetworkMap == nil -> config-only update (e.g. relay token rotation),
|
||||
// leave the previously applied checks untouched
|
||||
nm := update.GetNetworkMap()
|
||||
if nm == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := e.updateChecksIfNew(update.Checks); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nm := update.GetNetworkMap()
|
||||
if nm == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
|
||||
// A non-nil syncStore is what marks persistence as enabled. Hold the lock for
|
||||
// the whole Set so the store cannot be cleared (disabled / engine close)
|
||||
// mid-call and have this write resurrect a file that was just removed.
|
||||
e.syncRespMux.RLock()
|
||||
if e.syncStore != nil {
|
||||
if err := e.syncStore.Set(update); err != nil {
|
||||
log.Errorf("failed to persist sync response: %v", err)
|
||||
} else {
|
||||
log.Debugf("sync response persisted with serial %d", nm.GetSerial())
|
||||
}
|
||||
}
|
||||
e.syncRespMux.RUnlock()
|
||||
e.persistSyncResponse(update)
|
||||
|
||||
// only apply new changes and ignore old ones
|
||||
if err := e.updateNetworkMap(nm); err != nil {
|
||||
@@ -947,6 +910,64 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateNetbirdConfig applies the management-provided NetBird configuration:
|
||||
// STUN/TURN and relay servers, flow logging and DNS settings. A nil config is a no-op,
|
||||
// which is the case for sync updates carrying only a network map.
|
||||
func (e *Engine) updateNetbirdConfig(wCfg *mgmProto.NetbirdConfig) error {
|
||||
if wCfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := e.updateTURNs(wCfg.GetTurns()); err != nil {
|
||||
return fmt.Errorf("update TURNs: %w", err)
|
||||
}
|
||||
|
||||
if err := e.updateSTUNs(wCfg.GetStuns()); err != nil {
|
||||
return fmt.Errorf("update STUNs: %w", err)
|
||||
}
|
||||
|
||||
var stunTurn []*stun.URI
|
||||
stunTurn = append(stunTurn, e.STUNs...)
|
||||
stunTurn = append(stunTurn, e.TURNs...)
|
||||
e.stunTurn.Store(stunTurn)
|
||||
|
||||
if err := e.handleRelayUpdate(wCfg.GetRelay()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := e.handleFlowUpdate(wCfg.GetFlow()); err != nil {
|
||||
return fmt.Errorf("handle the flow configuration: %w", err)
|
||||
}
|
||||
|
||||
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
|
||||
log.Warnf("Failed to update DNS server config: %v", err)
|
||||
}
|
||||
|
||||
// todo update signal
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// persistSyncResponse stores the full sync response so it can be restored on the next
|
||||
// startup. Persistence is enabled only when syncStore is set. The dedicated syncRespMux
|
||||
// (not syncMsgMux) is held for the whole Set so the store cannot be cleared (disabled /
|
||||
// engine close) mid-call and have this write resurrect a file that was just removed.
|
||||
func (e *Engine) persistSyncResponse(update *mgmProto.SyncResponse) {
|
||||
e.syncRespMux.RLock()
|
||||
defer e.syncRespMux.RUnlock()
|
||||
|
||||
if e.syncStore == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := e.syncStore.Set(update); err != nil {
|
||||
log.Errorf("failed to persist sync response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("sync response persisted with serial %d", update.GetNetworkMap().GetSerial())
|
||||
}
|
||||
|
||||
func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
|
||||
if update != nil {
|
||||
// when we receive token we expect valid address list too
|
||||
|
||||
565
client/internal/engine_privileged_test.go
Normal file
565
client/internal/engine_privileged_test.go
Normal file
@@ -0,0 +1,565 @@
|
||||
//go:build privileged
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func TestEngine_SSH(t *testing.T) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(
|
||||
ctx, cancel,
|
||||
&EngineConfig{
|
||||
WgIfaceName: "utun101",
|
||||
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
ServerSSHAllowed: true,
|
||||
MTU: iface.DefaultMTU,
|
||||
SSHKey: sshKey,
|
||||
},
|
||||
EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
},
|
||||
MobileDependency{},
|
||||
)
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
|
||||
err = engine.Start(nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
err := engine.Stop()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
peerWithSSH := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
AllowedIps: []string{"100.64.0.21/24"},
|
||||
SshConfig: &mgmtProto.SSHConfig{
|
||||
SshPubKey: []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ"),
|
||||
},
|
||||
}
|
||||
|
||||
// SSH server is not enabled so SSH config of a remote peer should be ignored
|
||||
networkMap := &mgmtProto.NetworkMap{
|
||||
Serial: 6,
|
||||
PeerConfig: nil,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, engine.sshServer)
|
||||
|
||||
// SSH server is enabled, therefore SSH config should be applied
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 7,
|
||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||
SshConfig: &mgmtProto.SSHConfig{
|
||||
SshEnabled: true,
|
||||
JwtConfig: &mgmtProto.JWTConfig{
|
||||
Issuer: "test-issuer",
|
||||
Audience: "test-audience",
|
||||
KeysLocation: "test-keys",
|
||||
MaxTokenAge: 3600,
|
||||
},
|
||||
}},
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
|
||||
// now remove peer
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 8,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
// time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
|
||||
// now disable SSH server
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 9,
|
||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||
SshConfig: &mgmtProto.SSHConfig{SshEnabled: false}},
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, engine.sshServer)
|
||||
}
|
||||
|
||||
func TestEngine_Sync(t *testing.T) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// feed updates to Engine via mocked Management client
|
||||
updates := make(chan *mgmtProto.SyncResponse)
|
||||
defer close(updates)
|
||||
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||
for msg := range updates {
|
||||
err := msgHandler(msg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||
WgIfaceName: "utun103",
|
||||
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{})
|
||||
engine.ctx = ctx
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := engine.Stop()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
err = engine.Start(nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
peer1 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
AllowedIps: []string{"100.64.0.10/24"},
|
||||
}
|
||||
peer2 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
||||
AllowedIps: []string{"100.64.0.11/24"},
|
||||
}
|
||||
peer3 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
||||
AllowedIps: []string{"100.64.0.12/24"},
|
||||
}
|
||||
// 1st update with just 1 peer and serial larger than the current serial of the engine => apply update
|
||||
updates <- &mgmtProto.SyncResponse{
|
||||
NetworkMap: &mgmtProto.NetworkMap{
|
||||
Serial: 10,
|
||||
PeerConfig: nil,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
|
||||
RemotePeersIsEmpty: false,
|
||||
},
|
||||
}
|
||||
|
||||
timeout := time.After(time.Second * 2)
|
||||
for {
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatalf("timeout while waiting for test to finish")
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if getPeers(engine) == 3 && engine.networkSerial == 10 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_MultiplePeers(t *testing.T) {
|
||||
// log.SetLevel(log.DebugLevel)
|
||||
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
sigServer, signalAddr, err := startSignal(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer sigServer.Stop()
|
||||
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer mgmtServer.GracefulStop()
|
||||
|
||||
setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
|
||||
mu := sync.Mutex{}
|
||||
engines := []*Engine{}
|
||||
numPeers := 10
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(numPeers)
|
||||
// create and start peers
|
||||
for i := 0; i < numPeers; i++ {
|
||||
j := i
|
||||
go func() {
|
||||
engine, err := createEngine(ctx, cancel, setupKey, j, mgmtAddr, signalAddr)
|
||||
if err != nil {
|
||||
wg.Done()
|
||||
t.Errorf("unable to create the engine for peer %d with error %v", j, err)
|
||||
return
|
||||
}
|
||||
engine.dnsServer = &dns.MockServer{}
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
guid := fmt.Sprintf("{%s}", uuid.New().String())
|
||||
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
||||
err = engine.Start(nil, nil)
|
||||
if err != nil {
|
||||
t.Errorf("unable to start engine for peer %d with error %v", j, err)
|
||||
wg.Done()
|
||||
return
|
||||
}
|
||||
engines = append(engines, engine)
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
// wait until all have been created and started
|
||||
wg.Wait()
|
||||
if len(engines) != numPeers {
|
||||
t.Fatal("not all peers was started")
|
||||
}
|
||||
// check whether all the peer have expected peers connected
|
||||
|
||||
expectedConnected := numPeers * (numPeers - 1)
|
||||
|
||||
// adjust according to timeouts
|
||||
timeout := 50 * time.Second
|
||||
timeoutChan := time.After(timeout)
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-timeoutChan:
|
||||
t.Fatalf("waiting for expected connections timeout after %s", timeout.String())
|
||||
break loop
|
||||
case <-ticker.C:
|
||||
totalConnected := 0
|
||||
for _, engine := range engines {
|
||||
totalConnected += getConnectedPeers(engine)
|
||||
}
|
||||
if totalConnected == expectedConnected {
|
||||
log.Infof("total connected=%d", totalConnected)
|
||||
break loop
|
||||
}
|
||||
log.Infof("total connected=%d", totalConnected)
|
||||
}
|
||||
}
|
||||
// cleanup test
|
||||
for n, peerEngine := range engines {
|
||||
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name(), n)
|
||||
errStop := peerEngine.mgmClient.Close()
|
||||
if errStop != nil {
|
||||
log.Infoln("got error trying to close management clients from engine: ", errStop)
|
||||
}
|
||||
errStop = peerEngine.Stop()
|
||||
if errStop != nil {
|
||||
log.Infoln("got error trying to close testing peers engine: ", errStop)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
kaep = keepalive.EnforcementPolicy{
|
||||
MinTime: 15 * time.Second,
|
||||
PermitWithoutStream: true,
|
||||
}
|
||||
|
||||
kasp = keepalive.ServerParameters{
|
||||
MaxConnectionIdle: 15 * time.Second,
|
||||
MaxConnectionAgeGrace: 5 * time.Second,
|
||||
Time: 5 * time.Second,
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
)
|
||||
|
||||
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info := system.GetInfo(ctx)
|
||||
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ifaceName string
|
||||
if runtime.GOOS == "darwin" {
|
||||
ifaceName = fmt.Sprintf("utun1%d", i)
|
||||
} else {
|
||||
ifaceName = fmt.Sprintf("wt%d", i)
|
||||
}
|
||||
|
||||
wgPort := 33100 + i
|
||||
conf := &EngineConfig{
|
||||
WgIfaceName: ifaceName,
|
||||
WgAddr: wgaddr.MustParseWGAddress(resp.PeerConfig.Address),
|
||||
WgPrivateKey: key,
|
||||
WgPort: wgPort,
|
||||
MTU: iface.DefaultMTU,
|
||||
}
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||
SignalClient: signalClient,
|
||||
MgmClient: mgmtClient,
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{}), nil
|
||||
e.ctx = ctx
|
||||
return e, err
|
||||
}
|
||||
|
||||
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
log.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
|
||||
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
proto.RegisterSignalExchangeServer(s, srv)
|
||||
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
config := &config.Config{
|
||||
Stuns: []*config.Host{},
|
||||
TURNConfig: &config.TURNConfig{},
|
||||
Relay: &config.Relay{
|
||||
Addresses: []string{"127.0.0.1:1234"},
|
||||
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||
Secret: "222222222222222222",
|
||||
},
|
||||
Signal: &config.Host{
|
||||
Proto: "http",
|
||||
URI: "localhost:10000",
|
||||
},
|
||||
Datadir: dataDir,
|
||||
HttpConfig: nil,
|
||||
}
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||
|
||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
settingsMockManager.EXPECT().
|
||||
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
settingsMockManager.EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
// getConnectedPeers returns a connection Status or nil if peer connection wasn't found
|
||||
func getConnectedPeers(e *Engine) int {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
i := 0
|
||||
for _, id := range e.peerStore.PeersPubKey() {
|
||||
conn, _ := e.peerStore.PeerConn(id)
|
||||
if conn.IsConnected() {
|
||||
i++
|
||||
}
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
func getPeers(e *Engine) int {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
return len(e.peerStore.PeersPubKey())
|
||||
}
|
||||
@@ -6,37 +6,18 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
@@ -50,18 +31,7 @@ import (
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
||||
@@ -69,25 +39,9 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var (
|
||||
kaep = keepalive.EnforcementPolicy{
|
||||
MinTime: 15 * time.Second,
|
||||
PermitWithoutStream: true,
|
||||
}
|
||||
|
||||
kasp = keepalive.ServerParameters{
|
||||
MaxConnectionIdle: 15 * time.Second,
|
||||
MaxConnectionAgeGrace: 5 * time.Second,
|
||||
Time: 5 * time.Second,
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
)
|
||||
|
||||
type MockWGIface struct {
|
||||
CreateFunc func() error
|
||||
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
|
||||
@@ -234,129 +188,6 @@ func TestMain(m *testing.M) {
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestEngine_SSH(t *testing.T) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(
|
||||
ctx, cancel,
|
||||
&EngineConfig{
|
||||
WgIfaceName: "utun101",
|
||||
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
ServerSSHAllowed: true,
|
||||
MTU: iface.DefaultMTU,
|
||||
SSHKey: sshKey,
|
||||
},
|
||||
EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
},
|
||||
MobileDependency{},
|
||||
)
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
|
||||
err = engine.Start(nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
err := engine.Stop()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
peerWithSSH := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
AllowedIps: []string{"100.64.0.21/24"},
|
||||
SshConfig: &mgmtProto.SSHConfig{
|
||||
SshPubKey: []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ"),
|
||||
},
|
||||
}
|
||||
|
||||
// SSH server is not enabled so SSH config of a remote peer should be ignored
|
||||
networkMap := &mgmtProto.NetworkMap{
|
||||
Serial: 6,
|
||||
PeerConfig: nil,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, engine.sshServer)
|
||||
|
||||
// SSH server is enabled, therefore SSH config should be applied
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 7,
|
||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||
SshConfig: &mgmtProto.SSHConfig{
|
||||
SshEnabled: true,
|
||||
JwtConfig: &mgmtProto.JWTConfig{
|
||||
Issuer: "test-issuer",
|
||||
Audience: "test-audience",
|
||||
KeysLocation: "test-keys",
|
||||
MaxTokenAge: 3600,
|
||||
},
|
||||
}},
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
|
||||
// now remove peer
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 8,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
// time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
|
||||
// now disable SSH server
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 9,
|
||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||
SshConfig: &mgmtProto.SSHConfig{SshEnabled: false}},
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, engine.sshServer)
|
||||
}
|
||||
|
||||
func TestEngine_SSHUpdateLogic(t *testing.T) {
|
||||
// Test that SSH server start/stop logic works based on config
|
||||
engine := &Engine{
|
||||
@@ -631,97 +462,6 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_Sync(t *testing.T) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// feed updates to Engine via mocked Management client
|
||||
updates := make(chan *mgmtProto.SyncResponse)
|
||||
defer close(updates)
|
||||
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||
for msg := range updates {
|
||||
err := msgHandler(msg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||
WgIfaceName: "utun103",
|
||||
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{})
|
||||
engine.ctx = ctx
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := engine.Stop()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
err = engine.Start(nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
peer1 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
AllowedIps: []string{"100.64.0.10/24"},
|
||||
}
|
||||
peer2 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
||||
AllowedIps: []string{"100.64.0.11/24"},
|
||||
}
|
||||
peer3 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
||||
AllowedIps: []string{"100.64.0.12/24"},
|
||||
}
|
||||
// 1st update with just 1 peer and serial larger than the current serial of the engine => apply update
|
||||
updates <- &mgmtProto.SyncResponse{
|
||||
NetworkMap: &mgmtProto.NetworkMap{
|
||||
Serial: 10,
|
||||
PeerConfig: nil,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
|
||||
RemotePeersIsEmpty: false,
|
||||
},
|
||||
}
|
||||
|
||||
timeout := time.After(time.Second * 2)
|
||||
for {
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatalf("timeout while waiting for test to finish")
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if getPeers(engine) == 3 && engine.networkSerial == 10 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -1105,104 +845,6 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_MultiplePeers(t *testing.T) {
|
||||
// log.SetLevel(log.DebugLevel)
|
||||
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
sigServer, signalAddr, err := startSignal(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer sigServer.Stop()
|
||||
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer mgmtServer.GracefulStop()
|
||||
|
||||
setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
|
||||
mu := sync.Mutex{}
|
||||
engines := []*Engine{}
|
||||
numPeers := 10
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(numPeers)
|
||||
// create and start peers
|
||||
for i := 0; i < numPeers; i++ {
|
||||
j := i
|
||||
go func() {
|
||||
engine, err := createEngine(ctx, cancel, setupKey, j, mgmtAddr, signalAddr)
|
||||
if err != nil {
|
||||
wg.Done()
|
||||
t.Errorf("unable to create the engine for peer %d with error %v", j, err)
|
||||
return
|
||||
}
|
||||
engine.dnsServer = &dns.MockServer{}
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
guid := fmt.Sprintf("{%s}", uuid.New().String())
|
||||
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
||||
err = engine.Start(nil, nil)
|
||||
if err != nil {
|
||||
t.Errorf("unable to start engine for peer %d with error %v", j, err)
|
||||
wg.Done()
|
||||
return
|
||||
}
|
||||
engines = append(engines, engine)
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
// wait until all have been created and started
|
||||
wg.Wait()
|
||||
if len(engines) != numPeers {
|
||||
t.Fatal("not all peers was started")
|
||||
}
|
||||
// check whether all the peer have expected peers connected
|
||||
|
||||
expectedConnected := numPeers * (numPeers - 1)
|
||||
|
||||
// adjust according to timeouts
|
||||
timeout := 50 * time.Second
|
||||
timeoutChan := time.After(timeout)
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-timeoutChan:
|
||||
t.Fatalf("waiting for expected connections timeout after %s", timeout.String())
|
||||
break loop
|
||||
case <-ticker.C:
|
||||
totalConnected := 0
|
||||
for _, engine := range engines {
|
||||
totalConnected += getConnectedPeers(engine)
|
||||
}
|
||||
if totalConnected == expectedConnected {
|
||||
log.Infof("total connected=%d", totalConnected)
|
||||
break loop
|
||||
}
|
||||
log.Infof("total connected=%d", totalConnected)
|
||||
}
|
||||
}
|
||||
// cleanup test
|
||||
for n, peerEngine := range engines {
|
||||
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name(), n)
|
||||
errStop := peerEngine.mgmClient.Close()
|
||||
if errStop != nil {
|
||||
log.Infoln("got error trying to close management clients from engine: ", errStop)
|
||||
}
|
||||
errStop = peerEngine.Stop()
|
||||
if errStop != nil {
|
||||
log.Infoln("got error trying to close testing peers engine: ", errStop)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ParseNATExternalIPMappings(t *testing.T) {
|
||||
ifaceList, err := net.Interfaces()
|
||||
if err != nil {
|
||||
@@ -1526,187 +1168,6 @@ func TestCompareNetIPLists(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info := system.GetInfo(ctx)
|
||||
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ifaceName string
|
||||
if runtime.GOOS == "darwin" {
|
||||
ifaceName = fmt.Sprintf("utun1%d", i)
|
||||
} else {
|
||||
ifaceName = fmt.Sprintf("wt%d", i)
|
||||
}
|
||||
|
||||
wgPort := 33100 + i
|
||||
conf := &EngineConfig{
|
||||
WgIfaceName: ifaceName,
|
||||
WgAddr: wgaddr.MustParseWGAddress(resp.PeerConfig.Address),
|
||||
WgPrivateKey: key,
|
||||
WgPort: wgPort,
|
||||
MTU: iface.DefaultMTU,
|
||||
}
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||
SignalClient: signalClient,
|
||||
MgmClient: mgmtClient,
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{}), nil
|
||||
e.ctx = ctx
|
||||
return e, err
|
||||
}
|
||||
|
||||
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
log.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
|
||||
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
proto.RegisterSignalExchangeServer(s, srv)
|
||||
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
config := &config.Config{
|
||||
Stuns: []*config.Host{},
|
||||
TURNConfig: &config.TURNConfig{},
|
||||
Relay: &config.Relay{
|
||||
Addresses: []string{"127.0.0.1:1234"},
|
||||
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||
Secret: "222222222222222222",
|
||||
},
|
||||
Signal: &config.Host{
|
||||
Proto: "http",
|
||||
URI: "localhost:10000",
|
||||
},
|
||||
Datadir: dataDir,
|
||||
HttpConfig: nil,
|
||||
}
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||
|
||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
settingsMockManager.EXPECT().
|
||||
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
settingsMockManager.EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
// getConnectedPeers returns a connection Status or nil if peer connection wasn't found
|
||||
func getConnectedPeers(e *Engine) int {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
i := 0
|
||||
for _, id := range e.peerStore.PeersPubKey() {
|
||||
conn, _ := e.peerStore.PeerConn(id)
|
||||
if conn.IsConnected() {
|
||||
i++
|
||||
}
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
func getPeers(e *Engine) int {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
return len(e.peerStore.PeersPubKey())
|
||||
}
|
||||
|
||||
func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte {
|
||||
t.Helper()
|
||||
b, err := netiputil.EncodePrefix(p)
|
||||
|
||||
@@ -26,7 +26,6 @@ type connStatusInputs struct {
|
||||
iceInProgress bool // a negotiation is currently in flight
|
||||
}
|
||||
|
||||
|
||||
// ConnStatus describe the status of a peer's connection
|
||||
type ConnStatus int32
|
||||
|
||||
|
||||
@@ -193,6 +193,7 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
|
||||
type Status struct {
|
||||
mux sync.RWMutex
|
||||
peers map[string]State
|
||||
ipToKey map[string]string
|
||||
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
|
||||
signalState bool
|
||||
signalError error
|
||||
@@ -231,6 +232,7 @@ type Status struct {
|
||||
func NewRecorder(mgmAddress string) *Status {
|
||||
return &Status{
|
||||
peers: make(map[string]State),
|
||||
ipToKey: make(map[string]string),
|
||||
changeNotify: make(map[string]map[string]*StatusChangeSubscription),
|
||||
eventStreams: make(map[string]chan *proto.SystemEvent),
|
||||
eventQueue: NewEventQueue(eventQueueSize),
|
||||
@@ -282,6 +284,12 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string, ipv6 string)
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
d.peerListChangedForNotification = true
|
||||
if ipv6 != "" {
|
||||
d.ipToKey[ipv6] = peerPubKey
|
||||
}
|
||||
if ip != "" {
|
||||
d.ipToKey[ip] = peerPubKey
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -311,28 +319,22 @@ func (d *Status) PeerByIP(ip string) (string, bool) {
|
||||
|
||||
// PeerStateByIP returns the full peer State for the given tunnel IP.
|
||||
// Matches against either the IPv4 (State.IP) or IPv6 (State.IPv6) tunnel
|
||||
// address so dual-stack peers are reachable on either family. Searches
|
||||
// both d.peers and d.offlinePeers — peers that have been moved into
|
||||
// the offline slice by ReplaceOfflinePeers are still part of the
|
||||
// account's roster and callers (DNS filter, embed.Client.IdentityForIP)
|
||||
// need to recognise them rather than treating them as unknown. Returns
|
||||
// the zero State and false when no peer matches or the input is empty.
|
||||
// address so dual-stack peers are reachable on either family. Only
|
||||
// active peers are matched; peers moved into the offline slice by
|
||||
// ReplaceOfflinePeers are intentionally treated as unknown.
|
||||
func (d *Status) PeerStateByIP(ip string) (State, bool) {
|
||||
if ip == "" {
|
||||
return State{}, false
|
||||
}
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
|
||||
for _, state := range d.peers {
|
||||
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
|
||||
return state, true
|
||||
}
|
||||
key, ok := d.ipToKey[ip]
|
||||
if !ok {
|
||||
return State{}, false
|
||||
}
|
||||
for _, state := range d.offlinePeers {
|
||||
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
|
||||
return state, true
|
||||
}
|
||||
state, ok := d.peers[key]
|
||||
if ok {
|
||||
return state, true
|
||||
}
|
||||
return State{}, false
|
||||
}
|
||||
@@ -342,12 +344,18 @@ func (d *Status) RemovePeer(peerPubKey string) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
_, ok := d.peers[peerPubKey]
|
||||
p, ok := d.peers[peerPubKey]
|
||||
if !ok {
|
||||
return errors.New("no peer with to remove")
|
||||
}
|
||||
|
||||
delete(d.peers, peerPubKey)
|
||||
if mappedKey, exists := d.ipToKey[p.IP]; exists && mappedKey == peerPubKey {
|
||||
delete(d.ipToKey, p.IP)
|
||||
}
|
||||
if mappedKey, exists := d.ipToKey[p.IPv6]; exists && mappedKey == peerPubKey {
|
||||
delete(d.ipToKey, p.IPv6)
|
||||
}
|
||||
d.peerListChangedForNotification = true
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -90,12 +90,11 @@ func TestStatus_PeerStateByIP_MatchesIPv6(t *testing.T) {
|
||||
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
|
||||
}
|
||||
|
||||
// TestStatus_PeerStateByIP_MatchesOfflinePeers covers peers that have
|
||||
// been moved into the offline slice via ReplaceOfflinePeers. Callers
|
||||
// (DNS filter, embed.Client.IdentityForIP) need to treat them as known
|
||||
// rather than unknown — otherwise authentication / DNS filtering treats
|
||||
// known-but-offline peers as foreign IPs.
|
||||
func TestStatus_PeerStateByIP_MatchesOfflinePeers(t *testing.T) {
|
||||
// TestStatus_PeerStateByIP_IgnoresOfflinePeers documents that peers
|
||||
// moved into the offline slice via ReplaceOfflinePeers are intentionally
|
||||
// not resolvable by IP: only active peers can carry traffic, so callers
|
||||
// (DNS filter, embed.Client.IdentityForIP) treat them as unknown.
|
||||
func TestStatus_PeerStateByIP_IgnoresOfflinePeers(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
req := require.New(t)
|
||||
|
||||
@@ -103,13 +102,31 @@ func TestStatus_PeerStateByIP_MatchesOfflinePeers(t *testing.T) {
|
||||
{PubKey: "pk-offline", FQDN: "offline.netbird", IP: "100.64.0.20", IPv6: "fd00::20"},
|
||||
})
|
||||
|
||||
state, ok := status.PeerStateByIP("100.64.0.20")
|
||||
req.True(ok, "offline peer must resolve by IPv4 tunnel address")
|
||||
req.Equal("pk-offline", state.PubKey, "matching state must carry the offline peer's pub key")
|
||||
_, ok := status.PeerStateByIP("100.64.0.20")
|
||||
req.False(ok, "offline peer must not resolve by IPv4 tunnel address")
|
||||
|
||||
state, ok = status.PeerStateByIP("fd00::20")
|
||||
req.True(ok, "offline peer must resolve by IPv6 tunnel address")
|
||||
req.Equal("pk-offline", state.PubKey, "IPv6 match must carry the offline peer's pub key")
|
||||
_, ok = status.PeerStateByIP("fd00::20")
|
||||
req.False(ok, "offline peer must not resolve by IPv6 tunnel address")
|
||||
}
|
||||
|
||||
// TestStatus_PeerStateByIP_RemovedPeer verifies RemovePeer drops the
|
||||
// IP index entries for both address families.
|
||||
func TestStatus_PeerStateByIP_RemovedPeer(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
req := require.New(t)
|
||||
|
||||
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", "fd00::1"))
|
||||
|
||||
_, ok := status.PeerStateByIP("100.64.0.10")
|
||||
req.True(ok, "active peer must resolve before removal")
|
||||
|
||||
req.NoError(status.RemovePeer("pk-1"))
|
||||
|
||||
_, ok = status.PeerStateByIP("100.64.0.10")
|
||||
req.False(ok, "removed peer must not resolve by IPv4 tunnel address")
|
||||
|
||||
_, ok = status.PeerStateByIP("fd00::1")
|
||||
req.False(ok, "removed peer must not resolve by IPv6 tunnel address")
|
||||
}
|
||||
|
||||
func TestStatus_UpdatePeerFQDN(t *testing.T) {
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
@@ -57,6 +58,10 @@ var DefaultInterfaceBlacklist = []string{
|
||||
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
|
||||
}
|
||||
|
||||
// loadMDMPolicy is the package-level indirection used by apply() to read the
|
||||
// active MDM policy. Tests override this to inject a fake policy.
|
||||
var loadMDMPolicy = mdm.LoadPolicy
|
||||
|
||||
// ConfigInput carries configuration changes to the client
|
||||
type ConfigInput struct {
|
||||
ManagementURL string
|
||||
@@ -174,6 +179,23 @@ type Config struct {
|
||||
LazyConnectionEnabled bool
|
||||
|
||||
MTU uint16
|
||||
|
||||
// policy is the MDM policy that produced the currently-set values for
|
||||
// any MDM-enforced fields. Set by applyMDMPolicy at the tail of apply()
|
||||
// and reset on every apply() invocation. Never persisted to disk.
|
||||
// Callers query enforcement state via Policy() and the mdm.Policy API
|
||||
// (HasKey, ManagedKeys, IsEmpty).
|
||||
policy *mdm.Policy `json:"-"`
|
||||
}
|
||||
|
||||
// Policy returns the MDM policy applied to this Config. Returns a non-nil
|
||||
// empty Policy when MDM enforcement is inactive; callers can always invoke
|
||||
// HasKey / ManagedKeys / IsEmpty without a nil check.
|
||||
func (config *Config) Policy() *mdm.Policy {
|
||||
if config == nil || config.policy == nil {
|
||||
return mdm.NewPolicy(nil)
|
||||
}
|
||||
return config.policy
|
||||
}
|
||||
|
||||
var ConfigDirOverride string
|
||||
@@ -612,10 +634,93 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
// MDM is the last override layer: any key present in the policy
|
||||
// supersedes defaults, on-disk config, env vars and CLI input.
|
||||
config.applyMDMPolicy(loadMDMPolicy())
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// parseURL parses and validates a service URL
|
||||
// applyMDMPolicy overlays MDM-supplied values on top of the resolved Config.
|
||||
// The provided Policy is also stored on the Config so callers can later query
|
||||
// which fields are enforced. Invalid values (e.g. malformed URLs) are logged
|
||||
// and skipped to avoid bricking the client; the field keeps its previous
|
||||
// resolved value but is still marked as managed (Policy.HasKey returns true
|
||||
// for the key, so per-field rejection of user writes still applies).
|
||||
func (config *Config) applyMDMPolicy(policy *mdm.Policy) {
|
||||
config.policy = policy
|
||||
if policy.IsEmpty() {
|
||||
return
|
||||
}
|
||||
|
||||
// Helper: log the application of a single MDM-managed key. Values for
|
||||
// keys in mdm.SecretKeys are redacted.
|
||||
logApplied := func(key string, displayValue any) {
|
||||
if _, secret := mdm.SecretKeys[key]; secret {
|
||||
log.Infof("MDM override %s = ********** (secret)", key)
|
||||
return
|
||||
}
|
||||
log.Infof("MDM override %s = %v", key, displayValue)
|
||||
}
|
||||
|
||||
if v, ok := policy.GetString(mdm.KeyManagementURL); ok {
|
||||
if u, err := parseURL("Management URL", v); err != nil {
|
||||
log.Warnf("MDM management URL %q invalid: %v; keeping previous value", v, err)
|
||||
} else {
|
||||
config.ManagementURL = u
|
||||
logApplied(mdm.KeyManagementURL, u.String())
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := policy.GetString(mdm.KeyPreSharedKey); ok {
|
||||
// Defensive: refuse the redaction mask in case it round-tripped
|
||||
// through a manifest by mistake.
|
||||
if !isPreSharedKeyHidden(&v) {
|
||||
config.PreSharedKey = v
|
||||
logApplied(mdm.KeyPreSharedKey, "")
|
||||
}
|
||||
}
|
||||
|
||||
// applyBool collapses the per-key "read + set + log" boilerplate
|
||||
// for every plain bool MDM key into a single helper. Keeps the
|
||||
// outer function's cognitive complexity below SonarCube's
|
||||
// threshold; functional behaviour is identical to the inlined
|
||||
// branches it replaces.
|
||||
applyBool := func(key string, setter func(bool)) {
|
||||
v, ok := policy.GetBool(key)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
setter(v)
|
||||
logApplied(key, v)
|
||||
}
|
||||
|
||||
applyBool(mdm.KeyAllowServerSSH, func(v bool) { bv := v; config.ServerSSHAllowed = &bv })
|
||||
applyBool(mdm.KeyDisableClientRoutes, func(v bool) { config.DisableClientRoutes = v })
|
||||
applyBool(mdm.KeyDisableServerRoutes, func(v bool) { config.DisableServerRoutes = v })
|
||||
applyBool(mdm.KeyBlockInbound, func(v bool) { config.BlockInbound = v })
|
||||
applyBool(mdm.KeyDisableAutoConnect, func(v bool) { config.DisableAutoConnect = v })
|
||||
applyBool(mdm.KeyRosenpassEnabled, func(v bool) { config.RosenpassEnabled = v })
|
||||
applyBool(mdm.KeyRosenpassPermissive, func(v bool) { config.RosenpassPermissive = v })
|
||||
|
||||
if v, ok := policy.GetInt(mdm.KeyWireguardPort); ok {
|
||||
// REG_DWORD is 32-bit; UDP port range is 1-65535. Clamp at the
|
||||
// upper bound and reject obviously-invalid values to avoid the
|
||||
// engine binding to an unusable port if the admin pushes garbage.
|
||||
if v >= 1 && v <= 65535 {
|
||||
config.WgPort = int(v)
|
||||
logApplied(mdm.KeyWireguardPort, v)
|
||||
} else {
|
||||
log.Warnf("MDM wireguard port %d out of range [1,65535]; keeping previous value", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parseURL parses and validates the URL for the named service. The URL
|
||||
// must use the http or https scheme; if no port is present, ":443" is
|
||||
// appended for https or ":80" for http. The serviceName parameter is
|
||||
// used to contextualise error messages. On success returns the parsed
|
||||
// *url.URL; on failure returns a non-nil error.
|
||||
func parseURL(serviceName, serviceURL string) (*url.URL, error) {
|
||||
parsedMgmtURL, err := url.ParseRequestURI(serviceURL)
|
||||
if err != nil {
|
||||
|
||||
152
client/internal/profilemanager/config_mdm_test.go
Normal file
152
client/internal/profilemanager/config_mdm_test.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package profilemanager
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
)
|
||||
|
||||
// withMDMPolicy temporarily overrides the package-level loadMDMPolicy hook so
|
||||
// apply() observes the supplied Policy. The original loader is restored at
|
||||
// test cleanup.
|
||||
func withMDMPolicy(t *testing.T, policy *mdm.Policy) {
|
||||
t.Helper()
|
||||
prev := loadMDMPolicy
|
||||
loadMDMPolicy = func() *mdm.Policy { return policy }
|
||||
t.Cleanup(func() { loadMDMPolicy = prev })
|
||||
}
|
||||
|
||||
func TestApply_MDMEmpty_NoEnforcement(t *testing.T) {
|
||||
withMDMPolicy(t, mdm.NewPolicy(nil))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
assert.True(t, cfg.Policy().IsEmpty(), "no MDM source ⇒ empty Policy")
|
||||
assert.False(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
|
||||
assert.Empty(t, cfg.Policy().ManagedKeys())
|
||||
|
||||
// Default management URL still resolves.
|
||||
assert.Equal(t, DefaultManagementURL, cfg.ManagementURL.String())
|
||||
}
|
||||
|
||||
func TestApply_MDMOnly_OverridesDefaults(t *testing.T) {
|
||||
const mdmURL = "https://corp.mdm.example.com:443"
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: mdmURL,
|
||||
mdm.KeyDisableClientRoutes: true,
|
||||
mdm.KeyBlockInbound: true,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
assert.Equal(t, mdmURL, cfg.ManagementURL.String())
|
||||
assert.True(t, cfg.DisableClientRoutes)
|
||||
assert.True(t, cfg.BlockInbound)
|
||||
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyDisableClientRoutes))
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyBlockInbound))
|
||||
assert.False(t, cfg.Policy().HasKey(mdm.KeyAllowServerSSH))
|
||||
}
|
||||
|
||||
func TestApply_MDMBeatsCLIInput(t *testing.T) {
|
||||
const mdmURL = "https://mdm.example.com:443"
|
||||
const cliURL = "https://cli.example.com:443"
|
||||
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: mdmURL,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
ManagementURL: cliURL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
// MDM wins over CLI-supplied management URL.
|
||||
assert.Equal(t, mdmURL, cfg.ManagementURL.String())
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
|
||||
}
|
||||
|
||||
func TestApply_MDMInvalidURL_KeepsPreviousValue(t *testing.T) {
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "not-a-url",
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
// Invalid MDM URL is logged and skipped: default URL stays in place
|
||||
// to keep the client functional.
|
||||
assert.Equal(t, DefaultManagementURL, cfg.ManagementURL.String())
|
||||
|
||||
// But the key is still considered MDM-managed (admin intent is to
|
||||
// enforce, daemon rejects user writes to this field — phase-1 scaffolding
|
||||
// reflects this by keeping Policy.HasKey true even on parse failure).
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
|
||||
}
|
||||
|
||||
func TestApply_MDMBoolKeysOverrideOnDiskValue(t *testing.T) {
|
||||
tmp := filepath.Join(t.TempDir(), "config.json")
|
||||
|
||||
// Seed without MDM.
|
||||
withMDMPolicy(t, mdm.NewPolicy(nil))
|
||||
_, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: tmp,
|
||||
DisableClientRoutes: boolPtr(false),
|
||||
RosenpassEnabled: boolPtr(false),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now enable MDM enforcement for these keys.
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyDisableClientRoutes: true,
|
||||
mdm.KeyRosenpassEnabled: true,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{ConfigPath: tmp})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
assert.True(t, cfg.DisableClientRoutes, "MDM override should flip on-disk false to true")
|
||||
assert.True(t, cfg.RosenpassEnabled)
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyDisableClientRoutes))
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyRosenpassEnabled))
|
||||
}
|
||||
|
||||
func TestApply_MDMPreSharedKeyRedactionSentinelRejected(t *testing.T) {
|
||||
const maskSentinel = "**********"
|
||||
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyPreSharedKey: maskSentinel,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
// Mask sentinel must not be persisted as the actual PSK.
|
||||
assert.NotEqual(t, maskSentinel, cfg.PreSharedKey)
|
||||
// Key still marked managed so user writes are still rejected.
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyPreSharedKey))
|
||||
}
|
||||
|
||||
func boolPtr(b bool) *bool { return &b }
|
||||
@@ -700,6 +700,13 @@ func resolveURLsToIPs(urls []string) []net.IP {
|
||||
|
||||
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
|
||||
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
|
||||
// An explicit user "deselect all" must not be overridden by management auto-apply.
|
||||
// Auto-applying an exit node here would call SelectRoutes, which clears the
|
||||
// deselect-all flag and re-enables every route the user turned off.
|
||||
if m.routeSelector.IsDeselectAll() {
|
||||
return
|
||||
}
|
||||
|
||||
exitNodeInfo := m.collectExitNodeInfo(clientRoutes)
|
||||
if len(exitNodeInfo.allIDs) == 0 {
|
||||
return
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build privileged
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
|
||||
71
client/internal/routemanager/selector_management_test.go
Normal file
71
client/internal/routemanager/selector_management_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func exitNodeRoutes(netID route.NetID, skipAutoApply bool) route.HAMap {
|
||||
haID := route.HAUniqueID(string(netID) + "|0.0.0.0/0")
|
||||
return route.HAMap{
|
||||
haID: []*route.Route{
|
||||
{
|
||||
ID: "r-" + route.ID(netID),
|
||||
NetID: netID,
|
||||
Network: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Enabled: true,
|
||||
SkipAutoApply: skipAutoApply,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRouteSelectorFromManagement(t *testing.T) {
|
||||
t.Run("management auto-apply selects exit node without user selection", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
routes := exitNodeRoutes("exit1", false)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.True(t, m.routeSelector.IsSelected("exit1"), "auto-apply exit node should be selected")
|
||||
require.Len(t, m.routeSelector.FilterSelectedExitNodes(routes), 1, "selected exit node should pass the filter")
|
||||
})
|
||||
|
||||
t.Run("management SkipAutoApply leaves exit node deselected", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
routes := exitNodeRoutes("exit1", true)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.False(t, m.routeSelector.IsSelected("exit1"), "SkipAutoApply exit node should not be selected")
|
||||
require.Empty(t, m.routeSelector.FilterSelectedExitNodes(routes), "deselected exit node should be filtered out")
|
||||
})
|
||||
|
||||
t.Run("user selection is not overridden by management", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
require.NoError(t, m.routeSelector.SelectRoutes([]route.NetID{"exit1"}, true, []route.NetID{"exit1"}))
|
||||
routes := exitNodeRoutes("exit1", true)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.True(t, m.routeSelector.IsSelected("exit1"), "explicit user selection must survive a management sync that wants to skip auto-apply")
|
||||
require.Len(t, m.routeSelector.FilterSelectedExitNodes(routes), 1, "user-selected exit node should pass the filter")
|
||||
})
|
||||
|
||||
t.Run("deselect-all is preserved across a management sync", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
m.routeSelector.DeselectAllRoutes()
|
||||
routes := exitNodeRoutes("exit1", false)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.True(t, m.routeSelector.IsDeselectAll(), "an explicit deselect-all must not be cleared by management auto-apply")
|
||||
require.Empty(t, m.routeSelector.FilterSelectedExitNodes(routes), "no routes should be selected while deselect-all is set")
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEntryExists(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir)
|
||||
|
||||
content := []string{
|
||||
"1000 reserved",
|
||||
fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
|
||||
"9999 other_table",
|
||||
}
|
||||
require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644))
|
||||
|
||||
file, err := os.Open(tempFilePath)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
assert.NoError(t, file.Close())
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
id int
|
||||
shouldExist bool
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "ExistsWithNetbirdPrefix",
|
||||
id: 7120,
|
||||
shouldExist: true,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "ExistsWithDifferentName",
|
||||
id: 1000,
|
||||
shouldExist: true,
|
||||
err: ErrTableIDExists,
|
||||
},
|
||||
{
|
||||
name: "DoesNotExist",
|
||||
id: 1234,
|
||||
shouldExist: false,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
exists, err := entryExists(file, tc.id)
|
||||
if tc.err != nil {
|
||||
assert.ErrorIs(t, err, tc.err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.shouldExist, exists)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,191 @@
|
||||
//go:build (darwin || dragonfly || freebsd || netbsd || openbsd) && privileged
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func init() {
|
||||
testCases = append(testCases, []testCase{
|
||||
{
|
||||
name: "To more specific route without custom dialer via vpn",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
|
||||
},
|
||||
}...)
|
||||
}
|
||||
|
||||
func TestConcurrentRoutes(t *testing.T) {
|
||||
baseIP := netip.MustParseAddr("192.0.2.0")
|
||||
|
||||
var intf *net.Interface
|
||||
var nexthop Nexthop
|
||||
|
||||
_, intf = setupDummyInterface(t)
|
||||
nexthop = Nexthop{netip.Addr{}, intf}
|
||||
|
||||
r := New(nil, nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 1024; i++ {
|
||||
wg.Add(1)
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := r.addToRouteTable(prefix, nexthop); err != nil {
|
||||
t.Errorf("Failed to add route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
baseIP = baseIP.Next()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
baseIP = netip.MustParseAddr("192.0.2.0")
|
||||
|
||||
for i := 0; i < 1024; i++ {
|
||||
wg.Add(1)
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
|
||||
t.Errorf("Failed to remove route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
baseIP = baseIP.Next()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
|
||||
t.Helper()
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
|
||||
require.NoError(t, err, "Failed to create loopback alias")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
|
||||
assert.NoError(t, err, "Failed to remove loopback alias")
|
||||
})
|
||||
|
||||
return intf
|
||||
}
|
||||
|
||||
prefix, err := netip.ParsePrefix(ipAddressCIDR)
|
||||
require.NoError(t, err, "Failed to parse prefix")
|
||||
|
||||
netIntf, err := net.InterfaceByName(intf)
|
||||
require.NoError(t, err, "Failed to get interface by name")
|
||||
|
||||
nexthop := Nexthop{netip.Addr{}, netIntf}
|
||||
|
||||
r := New(nil, nil)
|
||||
err = r.addToRouteTable(prefix, nexthop)
|
||||
require.NoError(t, err, "Failed to add route to table")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := r.removeFromRouteTable(prefix, nexthop)
|
||||
assert.NoError(t, err, "Failed to remove route from table")
|
||||
})
|
||||
|
||||
return intf
|
||||
}
|
||||
|
||||
func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
|
||||
t.Helper()
|
||||
|
||||
var originalNexthop net.IP
|
||||
if dstCIDR == "0.0.0.0/0" {
|
||||
var err error
|
||||
originalNexthop, err = fetchOriginalGateway()
|
||||
if err != nil {
|
||||
t.Logf("Failed to fetch original gateway: %v", err)
|
||||
}
|
||||
|
||||
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
|
||||
t.Logf("Failed to delete route: %v, output: %s", err, output)
|
||||
}
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
if originalNexthop != nil {
|
||||
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
|
||||
assert.NoError(t, err, "Failed to restore original route")
|
||||
}
|
||||
})
|
||||
|
||||
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
|
||||
require.NoError(t, err, "Failed to add route")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
|
||||
assert.NoError(t, err, "Failed to remove route")
|
||||
})
|
||||
}
|
||||
|
||||
func fetchOriginalGateway() (net.IP, error) {
|
||||
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
|
||||
if len(matches) == 0 {
|
||||
return nil, fmt.Errorf("gateway not found")
|
||||
}
|
||||
|
||||
return net.ParseIP(matches[1]), nil
|
||||
}
|
||||
|
||||
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
|
||||
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
|
||||
t.Helper()
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
|
||||
}
|
||||
|
||||
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
|
||||
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
|
||||
|
||||
tunName := strings.TrimSpace(string(output))
|
||||
|
||||
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
|
||||
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
|
||||
|
||||
intf, err := net.InterfaceByName(tunName)
|
||||
require.NoError(t, err, "Failed to get interface by name")
|
||||
|
||||
t.Cleanup(func() {
|
||||
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
|
||||
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
|
||||
}
|
||||
})
|
||||
|
||||
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
|
||||
}
|
||||
|
||||
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
|
||||
addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
|
||||
|
||||
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
|
||||
addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
|
||||
}
|
||||
@@ -3,79 +3,24 @@
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/route"
|
||||
)
|
||||
|
||||
// Interface names used by the shared routing test fixtures. Kept untagged (no
|
||||
// privileged build tag) so the non-privileged test files in this package compile.
|
||||
//
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var expectedVPNint = "utun100"
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var expectedExternalInt = "lo0"
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var expectedInternalInt = "lo0"
|
||||
|
||||
func init() {
|
||||
testCases = append(testCases, []testCase{
|
||||
{
|
||||
name: "To more specific route without custom dialer via vpn",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
|
||||
},
|
||||
}...)
|
||||
}
|
||||
|
||||
func TestConcurrentRoutes(t *testing.T) {
|
||||
baseIP := netip.MustParseAddr("192.0.2.0")
|
||||
|
||||
var intf *net.Interface
|
||||
var nexthop Nexthop
|
||||
|
||||
_, intf = setupDummyInterface(t)
|
||||
nexthop = Nexthop{netip.Addr{}, intf}
|
||||
|
||||
r := New(nil, nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 1024; i++ {
|
||||
wg.Add(1)
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := r.addToRouteTable(prefix, nexthop); err != nil {
|
||||
t.Errorf("Failed to add route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
baseIP = baseIP.Next()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
baseIP = netip.MustParseAddr("192.0.2.0")
|
||||
|
||||
for i := 0; i < 1024; i++ {
|
||||
wg.Add(1)
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
|
||||
t.Errorf("Failed to remove route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
baseIP = baseIP.Next()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestBits(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -122,122 +67,3 @@ func TestBits(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
|
||||
t.Helper()
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
|
||||
require.NoError(t, err, "Failed to create loopback alias")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
|
||||
assert.NoError(t, err, "Failed to remove loopback alias")
|
||||
})
|
||||
|
||||
return intf
|
||||
}
|
||||
|
||||
prefix, err := netip.ParsePrefix(ipAddressCIDR)
|
||||
require.NoError(t, err, "Failed to parse prefix")
|
||||
|
||||
netIntf, err := net.InterfaceByName(intf)
|
||||
require.NoError(t, err, "Failed to get interface by name")
|
||||
|
||||
nexthop := Nexthop{netip.Addr{}, netIntf}
|
||||
|
||||
r := New(nil, nil)
|
||||
err = r.addToRouteTable(prefix, nexthop)
|
||||
require.NoError(t, err, "Failed to add route to table")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := r.removeFromRouteTable(prefix, nexthop)
|
||||
assert.NoError(t, err, "Failed to remove route from table")
|
||||
})
|
||||
|
||||
return intf
|
||||
}
|
||||
|
||||
func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
|
||||
t.Helper()
|
||||
|
||||
var originalNexthop net.IP
|
||||
if dstCIDR == "0.0.0.0/0" {
|
||||
var err error
|
||||
originalNexthop, err = fetchOriginalGateway()
|
||||
if err != nil {
|
||||
t.Logf("Failed to fetch original gateway: %v", err)
|
||||
}
|
||||
|
||||
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
|
||||
t.Logf("Failed to delete route: %v, output: %s", err, output)
|
||||
}
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
if originalNexthop != nil {
|
||||
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
|
||||
assert.NoError(t, err, "Failed to restore original route")
|
||||
}
|
||||
})
|
||||
|
||||
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
|
||||
require.NoError(t, err, "Failed to add route")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
|
||||
assert.NoError(t, err, "Failed to remove route")
|
||||
})
|
||||
}
|
||||
|
||||
func fetchOriginalGateway() (net.IP, error) {
|
||||
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
|
||||
if len(matches) == 0 {
|
||||
return nil, fmt.Errorf("gateway not found")
|
||||
}
|
||||
|
||||
return net.ParseIP(matches[1]), nil
|
||||
}
|
||||
|
||||
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
|
||||
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
|
||||
t.Helper()
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
|
||||
}
|
||||
|
||||
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
|
||||
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
|
||||
|
||||
tunName := strings.TrimSpace(string(output))
|
||||
|
||||
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
|
||||
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
|
||||
|
||||
intf, err := net.InterfaceByName(tunName)
|
||||
require.NoError(t, err, "Failed to get interface by name")
|
||||
|
||||
t.Cleanup(func() {
|
||||
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
|
||||
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
|
||||
}
|
||||
})
|
||||
|
||||
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
|
||||
}
|
||||
|
||||
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
|
||||
addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
|
||||
|
||||
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
|
||||
addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
// dialer is shared by the per-platform routing test cases. Kept untagged (no
|
||||
// privileged build tag) so the non-privileged test files compile on every platform.
|
||||
//
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
type dialer interface {
|
||||
Dial(network, address string) (net.Conn, error)
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !android && !ios
|
||||
//go:build !android && !ios && privileged
|
||||
|
||||
package systemops
|
||||
|
||||
@@ -26,11 +26,6 @@ import (
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
type dialer interface {
|
||||
Dial(network, address string) (net.Conn, error)
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
func TestAddVPNRoute(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -515,125 +510,3 @@ func setupTestEnv(t *testing.T) {
|
||||
// unique route in vpn table
|
||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
|
||||
}
|
||||
|
||||
func TestIsVpnRoute(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr string
|
||||
vpnRoutes []string
|
||||
localRoutes []string
|
||||
expectedVpn bool
|
||||
expectedPrefix netip.Prefix
|
||||
}{
|
||||
{
|
||||
name: "Match in VPN routes",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Match in local routes",
|
||||
addr: "10.1.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
},
|
||||
{
|
||||
name: "No match",
|
||||
addr: "172.16.0.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.Prefix{},
|
||||
},
|
||||
{
|
||||
name: "Default route ignored",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Default route matches but ignored",
|
||||
addr: "172.16.1.1",
|
||||
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.Prefix{},
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match local",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16"},
|
||||
localRoutes: []string{"192.168.1.0/24"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match local multiple",
|
||||
addr: "192.168.0.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
||||
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26", "192.168.0.0/28"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.0.0/28"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match vpn",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"192.168.0.0/16"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match vpn multiple",
|
||||
addr: "192.168.0.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
||||
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.0.0/27"),
|
||||
},
|
||||
{
|
||||
name: "Duplicate prefix in both",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"192.168.1.0/24"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr, err := netip.ParseAddr(tt.addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse address %s: %v", tt.addr, err)
|
||||
}
|
||||
|
||||
var vpnRoutes, localRoutes []netip.Prefix
|
||||
for _, route := range tt.vpnRoutes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse VPN route %s: %v", route, err)
|
||||
}
|
||||
vpnRoutes = append(vpnRoutes, prefix)
|
||||
}
|
||||
|
||||
for _, route := range tt.localRoutes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse local route %s: %v", route, err)
|
||||
}
|
||||
localRoutes = append(localRoutes, prefix)
|
||||
}
|
||||
|
||||
isVpn, matchedPrefix := isVpnRoute(addr, vpnRoutes, localRoutes)
|
||||
assert.Equal(t, tt.expectedVpn, isVpn, "isVpnRoute should return expectedVpn value")
|
||||
assert.Equal(t, tt.expectedPrefix, matchedPrefix, "isVpnRoute should return expectedVpn prefix")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,132 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsVpnRoute(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr string
|
||||
vpnRoutes []string
|
||||
localRoutes []string
|
||||
expectedVpn bool
|
||||
expectedPrefix netip.Prefix
|
||||
}{
|
||||
{
|
||||
name: "Match in VPN routes",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Match in local routes",
|
||||
addr: "10.1.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
},
|
||||
{
|
||||
name: "No match",
|
||||
addr: "172.16.0.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.Prefix{},
|
||||
},
|
||||
{
|
||||
name: "Default route ignored",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Default route matches but ignored",
|
||||
addr: "172.16.1.1",
|
||||
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.Prefix{},
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match local",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16"},
|
||||
localRoutes: []string{"192.168.1.0/24"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match local multiple",
|
||||
addr: "192.168.0.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
||||
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26", "192.168.0.0/28"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.0.0/28"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match vpn",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"192.168.0.0/16"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match vpn multiple",
|
||||
addr: "192.168.0.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
||||
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.0.0/27"),
|
||||
},
|
||||
{
|
||||
name: "Duplicate prefix in both",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"192.168.1.0/24"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr, err := netip.ParseAddr(tt.addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse address %s: %v", tt.addr, err)
|
||||
}
|
||||
|
||||
var vpnRoutes, localRoutes []netip.Prefix
|
||||
for _, route := range tt.vpnRoutes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse VPN route %s: %v", route, err)
|
||||
}
|
||||
vpnRoutes = append(vpnRoutes, prefix)
|
||||
}
|
||||
|
||||
for _, route := range tt.localRoutes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse local route %s: %v", route, err)
|
||||
}
|
||||
localRoutes = append(localRoutes, prefix)
|
||||
}
|
||||
|
||||
isVpn, matchedPrefix := isVpnRoute(addr, vpnRoutes, localRoutes)
|
||||
assert.Equal(t, tt.expectedVpn, isVpn, "isVpnRoute should return expectedVpn value")
|
||||
assert.Equal(t, tt.expectedPrefix, matchedPrefix, "isVpnRoute should return expectedVpn prefix")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,13 +1,10 @@
|
||||
//go:build !android
|
||||
//go:build !android && privileged
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
@@ -18,10 +15,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
)
|
||||
|
||||
var expectedVPNint = "wgtest0"
|
||||
var expectedExternalInt = "dummyext0"
|
||||
var expectedInternalInt = "dummyint0"
|
||||
|
||||
func init() {
|
||||
testCases = append(testCases, []testCase{
|
||||
{
|
||||
@@ -33,62 +26,6 @@ func init() {
|
||||
}...)
|
||||
}
|
||||
|
||||
func TestEntryExists(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir)
|
||||
|
||||
content := []string{
|
||||
"1000 reserved",
|
||||
fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
|
||||
"9999 other_table",
|
||||
}
|
||||
require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644))
|
||||
|
||||
file, err := os.Open(tempFilePath)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
assert.NoError(t, file.Close())
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
id int
|
||||
shouldExist bool
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "ExistsWithNetbirdPrefix",
|
||||
id: 7120,
|
||||
shouldExist: true,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "ExistsWithDifferentName",
|
||||
id: 1000,
|
||||
shouldExist: true,
|
||||
err: ErrTableIDExists,
|
||||
},
|
||||
{
|
||||
name: "DoesNotExist",
|
||||
id: 1234,
|
||||
shouldExist: false,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
exists, err := entryExists(file, tc.id)
|
||||
if tc.err != nil {
|
||||
assert.ErrorIs(t, err, tc.err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.shouldExist, exists)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package systemops
|
||||
|
||||
// Interface names used by the shared routing test fixtures. Kept untagged (no
|
||||
// privileged build tag) so the non-privileged test files in this package compile.
|
||||
//
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var expectedVPNint = "wgtest0"
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var expectedExternalInt = "dummyext0"
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var expectedInternalInt = "dummyint0"
|
||||
@@ -0,0 +1,83 @@
|
||||
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
// Shared, non-privileged routing test fixtures. The privileged TestRouting (and its
|
||||
// per-platform init() appenders) consume these; they live here so the unprivileged
|
||||
// BSD/darwin test files compile without the privileged build tag.
|
||||
|
||||
type PacketExpectation struct {
|
||||
SrcIP net.IP
|
||||
DstIP net.IP
|
||||
SrcPort int
|
||||
DstPort int
|
||||
UDP bool
|
||||
TCP bool
|
||||
}
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
type testCase struct {
|
||||
name string
|
||||
expectedInterface string
|
||||
dialer dialer
|
||||
expectedPacket PacketExpectation
|
||||
}
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var testCases = []testCase{
|
||||
{
|
||||
name: "To external host without custom dialer via vpn",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
|
||||
},
|
||||
{
|
||||
name: "To external host with custom dialer via physical interface",
|
||||
expectedInterface: expectedExternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To duplicate internal route with custom dialer via physical interface",
|
||||
expectedInterface: expectedInternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
||||
},
|
||||
{
|
||||
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
|
||||
expectedInterface: expectedInternalInt,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To unique vpn route with custom dialer via physical interface",
|
||||
expectedInterface: expectedExternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53),
|
||||
},
|
||||
{
|
||||
name: "To unique vpn route without custom dialer via vpn",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53),
|
||||
},
|
||||
}
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
|
||||
return PacketExpectation{
|
||||
SrcIP: net.ParseIP(srcIP),
|
||||
DstIP: net.ParseIP(dstIP),
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
UDP: true,
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
|
||||
//go:build ((linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly) && privileged
|
||||
|
||||
package systemops
|
||||
|
||||
@@ -20,63 +20,6 @@ import (
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
type PacketExpectation struct {
|
||||
SrcIP net.IP
|
||||
DstIP net.IP
|
||||
SrcPort int
|
||||
DstPort int
|
||||
UDP bool
|
||||
TCP bool
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
expectedInterface string
|
||||
dialer dialer
|
||||
expectedPacket PacketExpectation
|
||||
}
|
||||
|
||||
var testCases = []testCase{
|
||||
{
|
||||
name: "To external host without custom dialer via vpn",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
|
||||
},
|
||||
{
|
||||
name: "To external host with custom dialer via physical interface",
|
||||
expectedInterface: expectedExternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To duplicate internal route with custom dialer via physical interface",
|
||||
expectedInterface: expectedInternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
||||
},
|
||||
{
|
||||
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
|
||||
expectedInterface: expectedInternalInt,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To unique vpn route with custom dialer via physical interface",
|
||||
expectedInterface: expectedExternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53),
|
||||
},
|
||||
{
|
||||
name: "To unique vpn route without custom dialer via vpn",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53),
|
||||
},
|
||||
}
|
||||
|
||||
func TestRouting(t *testing.T) {
|
||||
nbnet.Init()
|
||||
for _, tc := range testCases {
|
||||
@@ -102,16 +45,6 @@ func TestRouting(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
|
||||
return PacketExpectation{
|
||||
SrcIP: net.ParseIP(srcIP),
|
||||
DstIP: net.ParseIP(dstIP),
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
UDP: true,
|
||||
}
|
||||
}
|
||||
|
||||
func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build privileged
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
|
||||
// interface so route lookups for global IPv6 prefixes resolve in environments
|
||||
// without v6 connectivity. If a default already exists it is left alone.
|
||||
//
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
func ensureIPv6DefaultRoute(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build linux && !android
|
||||
//go:build linux && !android && privileged
|
||||
|
||||
package systemops
|
||||
|
||||
|
||||
@@ -8,11 +8,14 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
const loopbackIfaceWindows = "Loopback Pseudo-Interface 1"
|
||||
|
||||
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
|
||||
// interface so route lookups for global IPv6 prefixes resolve in environments
|
||||
// without v6 connectivity. If a default already exists it is left alone.
|
||||
//
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
func ensureIPv6DefaultRoute(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -116,6 +116,14 @@ func (rs *RouteSelector) DeselectAllRoutes() {
|
||||
clear(rs.selectedRoutes)
|
||||
}
|
||||
|
||||
// IsDeselectAll reports whether the user has explicitly deselected all routes.
|
||||
func (rs *RouteSelector) IsDeselectAll() bool {
|
||||
rs.mu.RLock()
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
return rs.deselectAll
|
||||
}
|
||||
|
||||
// IsSelected checks if a specific route is selected.
|
||||
func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
||||
rs.mu.RLock()
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
package shell
|
||||
|
||||
import (
|
||||
"os/user"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestIntegration_ShellLookupChain tests the full shell resolution chain
|
||||
// (getShellFromPasswd -> getShellFromGetent -> $SHELL -> default) on Unix.
|
||||
func TestIntegration_ShellLookupChain(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Unix shell lookup not applicable on Windows")
|
||||
}
|
||||
|
||||
current, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
// getUserShell is the top-level function used by the SSH server.
|
||||
shell := GetUserShell(current.Uid)
|
||||
require.NotEmpty(t, shell, "getUserShell must always return a shell")
|
||||
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
|
||||
}
|
||||
50
client/mdm/canonical_loaders.go
Normal file
50
client/mdm/canonical_loaders.go
Normal file
@@ -0,0 +1,50 @@
|
||||
//go:build windows || darwin
|
||||
|
||||
package mdm
|
||||
|
||||
import "strings"
|
||||
|
||||
// allKeys is the set of recognised MDM keys. Unknown keys in a managed
|
||||
// configuration are ignored but logged. Lives in this build-tagged file
|
||||
// (windows || darwin) because only desktop loaders need the
|
||||
// canonicalisation table that consumes it; including it unconditionally
|
||||
// would trigger the `unused` golangci-lint check on platforms that
|
||||
// don't import canonical_loaders.go.
|
||||
var allKeys = []string{
|
||||
KeyManagementURL,
|
||||
KeyDisableUpdateSettings,
|
||||
KeyDisableProfiles,
|
||||
KeyDisableNetworks,
|
||||
KeyDisableClientRoutes,
|
||||
KeyDisableServerRoutes,
|
||||
KeyBlockInbound,
|
||||
KeyDisableMetricsCollection,
|
||||
KeyAllowServerSSH,
|
||||
KeyDisableAutoConnect,
|
||||
KeyPreSharedKey,
|
||||
KeyRosenpassEnabled,
|
||||
KeyRosenpassPermissive,
|
||||
KeyWireguardPort,
|
||||
KeySplitTunnelMode,
|
||||
KeySplitTunnelApps,
|
||||
}
|
||||
|
||||
// canonicalKey maps the lowercase form of a managed-config value name to
|
||||
// its canonical mdm.Key* form. Admins commonly write PascalCase value
|
||||
// names in ADMX / Group Policy ("ManagementURL"); the iOS/AppConfig and
|
||||
// macOS plist conventions are camelCase ("managementURL"); both must
|
||||
// resolve to the same Policy lookup.
|
||||
//
|
||||
// Lives in a desktop-loader-only file (build tag `windows || darwin`)
|
||||
// because no other build path consumes it. Linux / FreeBSD / mobile
|
||||
// builds don't ship a platform loader that reads arbitrary-case key
|
||||
// names, so they don't need the canonicalisation table — and including
|
||||
// the var unconditionally would trigger the `unused` golangci-lint
|
||||
// check on those platforms.
|
||||
var canonicalKey = func() map[string]string {
|
||||
m := make(map[string]string, len(allKeys))
|
||||
for _, k := range allKeys {
|
||||
m[strings.ToLower(k)] = k
|
||||
}
|
||||
return m
|
||||
}()
|
||||
247
client/mdm/policy.go
Normal file
247
client/mdm/policy.go
Normal file
@@ -0,0 +1,247 @@
|
||||
// Package mdm reads MDM-managed configuration from platform-native sources
|
||||
// (plist on macOS, registry on Windows, UserDefaults on iOS,
|
||||
// RestrictionsManager on Android). The returned Policy is consumed by
|
||||
// profilemanager.Config.apply() as the highest-priority override layer.
|
||||
//
|
||||
// An empty Policy (no source present, or source present with zero keys)
|
||||
// means no MDM enforcement is active and the client behaves as if the
|
||||
// feature did not exist.
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Well-known policy keys. Names mirror the corresponding ConfigInput Go field
|
||||
// names (lowerCamelCase) so the daemon can map a Policy key directly to a
|
||||
// configuration field.
|
||||
const (
|
||||
KeyManagementURL = "managementURL"
|
||||
KeyDisableUpdateSettings = "disableUpdateSettings"
|
||||
KeyDisableProfiles = "disableProfiles"
|
||||
KeyDisableNetworks = "disableNetworks"
|
||||
KeyDisableClientRoutes = "disableClientRoutes"
|
||||
KeyDisableServerRoutes = "disableServerRoutes"
|
||||
KeyBlockInbound = "blockInbound"
|
||||
KeyDisableMetricsCollection = "disableMetricsCollection"
|
||||
KeyAllowServerSSH = "allowServerSSH"
|
||||
KeyDisableAutoConnect = "disableAutoConnect"
|
||||
KeyPreSharedKey = "preSharedKey"
|
||||
KeyRosenpassEnabled = "rosenpassEnabled"
|
||||
KeyRosenpassPermissive = "rosenpassPermissive"
|
||||
KeyWireguardPort = "wireguardPort"
|
||||
|
||||
// Split tunnel is modeled as a single conceptual policy with two
|
||||
// registry/plist values. KeySplitTunnelMode is the discriminator
|
||||
// ("allow" or "disallow"); KeySplitTunnelApps is a comma-separated
|
||||
// list of package names. The values are mutually exclusive by
|
||||
// construction — only one mode can be set at a time.
|
||||
KeySplitTunnelMode = "splitTunnelMode"
|
||||
KeySplitTunnelApps = "splitTunnelApps"
|
||||
)
|
||||
|
||||
// Split-tunnel mode literals (KeySplitTunnelMode values).
|
||||
const (
|
||||
SplitTunnelModeAllow = "allow"
|
||||
SplitTunnelModeDisallow = "disallow"
|
||||
)
|
||||
|
||||
// SecretKeys lists keys whose values must be redacted in logs.
|
||||
var SecretKeys = map[string]struct{}{
|
||||
KeyPreSharedKey: {},
|
||||
}
|
||||
|
||||
// boolStringLiterals enumerates the textual boolean encodings the
|
||||
// platform loaders may produce (Windows REG_SZ "true", iOS / Android
|
||||
// managed-config booleans-as-strings, etc.). Lookup keeps GetBool flat
|
||||
// (no nested switch on the string case).
|
||||
var boolStringLiterals = map[string]bool{
|
||||
"true": true,
|
||||
"1": true,
|
||||
"yes": true,
|
||||
"false": false,
|
||||
"0": false,
|
||||
"no": false,
|
||||
}
|
||||
|
||||
|
||||
// Policy holds MDM-managed settings read from the platform source. A nil or
|
||||
// empty Policy means no enforcement is active.
|
||||
type Policy struct {
|
||||
values map[string]any
|
||||
}
|
||||
|
||||
// NewPolicy constructs a Policy from a key→value map. Pass nil or an
|
||||
// empty map to construct an empty (no-enforcement) Policy. The returned
|
||||
// *Policy is always non-nil.
|
||||
func NewPolicy(values map[string]any) *Policy {
|
||||
if values == nil {
|
||||
values = map[string]any{}
|
||||
}
|
||||
return &Policy{values: values}
|
||||
}
|
||||
|
||||
// LoadPolicy reads the platform-native MDM configuration. Returns an
|
||||
// empty (but non-nil) Policy when no source is present, the source is
|
||||
// empty, or the platform is unsupported.
|
||||
//
|
||||
// Diagnostic logging differentiates the three states:
|
||||
// - source absent / unsupported platform: trace log only
|
||||
// - source present, zero keys: info "MDM enrolled (no managed keys)"
|
||||
// - source present, N keys: info "MDM enrolled with N managed keys: [...]"
|
||||
func LoadPolicy() *Policy {
|
||||
values, err := loadPlatformPolicy()
|
||||
if err != nil {
|
||||
log.Tracef("MDM policy load: %v", err)
|
||||
return &Policy{values: map[string]any{}}
|
||||
}
|
||||
if values == nil {
|
||||
return &Policy{values: map[string]any{}}
|
||||
}
|
||||
if len(values) == 0 {
|
||||
log.Info("MDM enrolled (no managed keys)")
|
||||
} else {
|
||||
log.Infof("MDM enrolled with %d managed key(s): %v", len(values), sortedKeys(values))
|
||||
}
|
||||
return &Policy{values: values}
|
||||
}
|
||||
|
||||
// IsEmpty reports whether the Policy has no managed keys.
|
||||
func (p *Policy) IsEmpty() bool {
|
||||
return p == nil || len(p.values) == 0
|
||||
}
|
||||
|
||||
// HasKey reports whether the given key is MDM-managed.
|
||||
func (p *Policy) HasKey(key string) bool {
|
||||
if p == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := p.values[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ManagedKeys returns the sorted list of managed key names. Returns an empty
|
||||
// slice (not nil) on an empty Policy.
|
||||
func (p *Policy) ManagedKeys() []string {
|
||||
if p == nil {
|
||||
return []string{}
|
||||
}
|
||||
return sortedKeys(p.values)
|
||||
}
|
||||
|
||||
// GetString returns the managed value for key coerced to string, and whether
|
||||
// the key was set. A non-string value returns ("", false).
|
||||
func (p *Policy) GetString(key string) (string, bool) {
|
||||
if p == nil {
|
||||
return "", false
|
||||
}
|
||||
v, ok := p.values[key]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
s, ok := v.(string)
|
||||
if !ok || s == "" {
|
||||
return "", false
|
||||
}
|
||||
return s, true
|
||||
}
|
||||
|
||||
// GetBool returns the managed value for key coerced to bool, and whether the
|
||||
// key was set. Accepts native bool and string literals "true"/"false"/"1"/"0".
|
||||
func (p *Policy) GetBool(key string) (bool, bool) {
|
||||
if p == nil {
|
||||
return false, false
|
||||
}
|
||||
v, ok := p.values[key]
|
||||
if !ok {
|
||||
return false, false
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case bool:
|
||||
return t, true
|
||||
case string:
|
||||
b, known := boolStringLiterals[t]
|
||||
return b, known
|
||||
case int:
|
||||
return t != 0, true
|
||||
case int64:
|
||||
return t != 0, true
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
|
||||
// GetInt returns the managed value for key as int64, and whether the key
|
||||
// was set. Accepts native int / int64 (as produced by the Windows registry
|
||||
// loader for REG_DWORD/REG_QWORD) and numeric strings (decimal).
|
||||
func (p *Policy) GetInt(key string) (int64, bool) {
|
||||
if p == nil {
|
||||
return 0, false
|
||||
}
|
||||
v, ok := p.values[key]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case int64:
|
||||
return t, true
|
||||
case int:
|
||||
return int64(t), true
|
||||
case int32:
|
||||
return int64(t), true
|
||||
case uint64:
|
||||
return int64(t), true
|
||||
case float64:
|
||||
return int64(t), true
|
||||
case string:
|
||||
if n, err := strconv.ParseInt(t, 10, 64); err == nil {
|
||||
return n, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetStringSlice returns the managed value for key as []string, and whether
|
||||
// the key was set. Accepts []string, []any (of strings), and a single string
|
||||
// (treated as a one-element list).
|
||||
func (p *Policy) GetStringSlice(key string) ([]string, bool) {
|
||||
if p == nil {
|
||||
return nil, false
|
||||
}
|
||||
v, ok := p.values[key]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case []string:
|
||||
return append([]string(nil), t...), true
|
||||
case []any:
|
||||
out := make([]string, 0, len(t))
|
||||
for _, item := range t {
|
||||
s, ok := item.(string)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
out = append(out, s)
|
||||
}
|
||||
return out, true
|
||||
case string:
|
||||
return []string{t}, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// sortedKeys returns the keys of m as a deterministic, lexicographically
|
||||
// sorted slice. Used internally by Policy.ManagedKeys and LoadPolicy's
|
||||
// diagnostic log line so callers see a stable key order across runs
|
||||
// regardless of Go's randomised map iteration.
|
||||
func sortedKeys(m map[string]any) []string {
|
||||
out := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
out = append(out, k)
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
90
client/mdm/policy_darwin.go
Normal file
90
client/mdm/policy_darwin.go
Normal file
@@ -0,0 +1,90 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"howett.net/plist"
|
||||
)
|
||||
|
||||
// policyPlistPath is the well-known location where macOS writes the
|
||||
// device-level mandatory MDM payload for NetBird. The path is fixed by
|
||||
// Apple convention: when an MDM provider (Jamf / Kandji / Mosyle /
|
||||
// Intune for Mac / Workspace ONE) pushes a Configuration Profile that
|
||||
// contains a com.apple.ManagedClient.preferences payload targeting the
|
||||
// bundle id io.netbird.client, the OS materializes the payload here.
|
||||
//
|
||||
// Read-only — only the OS (root) is supposed to write this file. The
|
||||
// loader sanity-checks the file mode and refuses to honour a world-
|
||||
// writable plist, as a defense against tampered installs.
|
||||
const policyPlistPath = "/Library/Managed Preferences/io.netbird.client.plist"
|
||||
|
||||
// loadPlatformPolicy reads the MDM-managed configuration from the macOS
|
||||
// managed-preferences plist at policyPlistPath. Returns:
|
||||
// - (nil, nil) when the plist is absent (device not MDM-enrolled for
|
||||
// NetBird, or admin has not yet pushed a payload)
|
||||
// - (map, nil) with N entries when N managed values are present
|
||||
// (N may be 0 — empty plist still signals enrollment to the caller)
|
||||
// - (nil, err) on permission / parse / safety errors (including
|
||||
// refusal to read a world-writable plist)
|
||||
//
|
||||
// Top-level plist keys are canonicalised case-insensitively to the
|
||||
// package's internal mdm.Key* names; unknown keys are logged and
|
||||
// skipped so a stray entry in the payload does not block startup.
|
||||
// Native plist value types map naturally onto the Policy accessor
|
||||
// expectations (GetString / GetBool / GetInt / GetStringSlice).
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
f, err := os.Open(policyPlistPath)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
// Not enrolled for NetBird. Caller treats nil as
|
||||
// "no MDM source present".
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("open %s: %w", policyPlistPath, err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := f.Close(); closeErr != nil {
|
||||
log.Warnf("MDM close plist %s: %v", policyPlistPath, closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
info, err := f.Stat()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stat %s: %w", policyPlistPath, err)
|
||||
}
|
||||
// World-writable plist => tampered install. Refuse rather than
|
||||
// honour potentially attacker-controlled policy values.
|
||||
if info.Mode().Perm()&0o002 != 0 {
|
||||
return nil, fmt.Errorf("refusing to read world-writable MDM source %s (mode %o)",
|
||||
policyPlistPath, info.Mode().Perm())
|
||||
}
|
||||
|
||||
raw := make(map[string]any)
|
||||
if err := plist.NewDecoder(f).Decode(&raw); err != nil {
|
||||
return nil, fmt.Errorf("decode plist %s: %w", policyPlistPath, err)
|
||||
}
|
||||
|
||||
out := make(map[string]any, len(raw))
|
||||
for name, val := range raw {
|
||||
// macOS / AppConfig conventions both use camelCase for managed
|
||||
// preferences keys; canonicalize to the mdm.Key* form so a key
|
||||
// written as "ManagementURL" (PascalCase, rare on macOS but
|
||||
// possible if the admin reused an ADMX-style name) still
|
||||
// resolves.
|
||||
canonical, known := canonicalKey[strings.ToLower(name)]
|
||||
if !known {
|
||||
log.Warnf("MDM ignoring unknown plist key %s: %s", policyPlistPath, name)
|
||||
continue
|
||||
}
|
||||
out[canonical] = val
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
14
client/mdm/policy_mobile.go
Normal file
14
client/mdm/policy_mobile.go
Normal file
@@ -0,0 +1,14 @@
|
||||
//go:build ios || android
|
||||
|
||||
package mdm
|
||||
|
||||
// loadPlatformPolicy is unused on mobile: the native layer (Swift on iOS,
|
||||
// Kotlin/Java on Android) reads the OS managed-config store and pushes the
|
||||
// resulting dictionary in-process via a gomobile entry point that lands in
|
||||
// Phase 5 / Phase 6. The stub keeps the package compilable for mobile
|
||||
// builds and returns (nil, nil) — the platform-absent sentinel that
|
||||
// LoadPolicy in policy.go treats as "no MDM source present".
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
|
||||
return nil, nil
|
||||
}
|
||||
14
client/mdm/policy_other.go
Normal file
14
client/mdm/policy_other.go
Normal file
@@ -0,0 +1,14 @@
|
||||
//go:build !windows && !darwin && !ios && !android
|
||||
|
||||
package mdm
|
||||
|
||||
// loadPlatformPolicy returns no policy on platforms without an MDM channel
|
||||
// (Linux, FreeBSD). MDM enforcement is off and the client behaves as if
|
||||
// the feature did not exist. Returns (nil, nil) — the platform-absent
|
||||
// sentinel the caller (LoadPolicy in policy.go) treats as "no MDM
|
||||
// source present"; an error here would just translate to the same
|
||||
// outcome with an extra log line.
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
|
||||
return nil, nil
|
||||
}
|
||||
160
client/mdm/policy_test.go
Normal file
160
client/mdm/policy_test.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPolicy_NilSafe(t *testing.T) {
|
||||
var p *Policy
|
||||
assert.True(t, p.IsEmpty())
|
||||
assert.False(t, p.HasKey(KeyManagementURL))
|
||||
assert.Empty(t, p.ManagedKeys())
|
||||
|
||||
_, ok := p.GetString(KeyManagementURL)
|
||||
assert.False(t, ok)
|
||||
_, ok = p.GetBool(KeyDisableProfiles)
|
||||
assert.False(t, ok)
|
||||
_, ok = p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestPolicy_Empty(t *testing.T) {
|
||||
p := NewPolicy(nil)
|
||||
require.NotNil(t, p)
|
||||
assert.True(t, p.IsEmpty())
|
||||
assert.False(t, p.HasKey(KeyManagementURL))
|
||||
assert.Empty(t, p.ManagedKeys())
|
||||
}
|
||||
|
||||
func TestPolicy_HasKey(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeyManagementURL: "https://corp.example.com",
|
||||
KeyDisableProfiles: true,
|
||||
})
|
||||
assert.False(t, p.IsEmpty())
|
||||
assert.True(t, p.HasKey(KeyManagementURL))
|
||||
assert.True(t, p.HasKey(KeyDisableProfiles))
|
||||
assert.False(t, p.HasKey(KeyPreSharedKey))
|
||||
}
|
||||
|
||||
func TestPolicy_ManagedKeysSorted(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeyDisableProfiles: true,
|
||||
KeyManagementURL: "https://x",
|
||||
KeyAllowServerSSH: false,
|
||||
})
|
||||
got := p.ManagedKeys()
|
||||
assert.Equal(t, []string{KeyAllowServerSSH, KeyDisableProfiles, KeyManagementURL}, got)
|
||||
}
|
||||
|
||||
func TestPolicy_GetString(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeyManagementURL: "https://corp.example.com",
|
||||
KeyDisableProfiles: true, // wrong type for GetString
|
||||
KeyPreSharedKey: "", // empty rejected
|
||||
})
|
||||
v, ok := p.GetString(KeyManagementURL)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "https://corp.example.com", v)
|
||||
|
||||
_, ok = p.GetString(KeyDisableProfiles)
|
||||
assert.False(t, ok, "non-string value must not be reported as string")
|
||||
|
||||
_, ok = p.GetString(KeyPreSharedKey)
|
||||
assert.False(t, ok, "empty string treated as unset")
|
||||
|
||||
_, ok = p.GetString("nonexistent")
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestPolicy_GetBool(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
raw any
|
||||
want bool
|
||||
ok bool
|
||||
}{
|
||||
{"native true", true, true, true},
|
||||
{"native false", false, false, true},
|
||||
{"string true", "true", true, true},
|
||||
{"string false", "false", false, true},
|
||||
{"string 1", "1", true, true},
|
||||
{"string 0", "0", false, true},
|
||||
{"string yes", "yes", true, true},
|
||||
{"string no", "no", false, true},
|
||||
{"int nonzero", 1, true, true},
|
||||
{"int zero", 0, false, true},
|
||||
{"int64 nonzero", int64(2), true, true},
|
||||
{"int64 zero", int64(0), false, true},
|
||||
{"string garbage", "maybe", false, false},
|
||||
{"float unsupported", 1.0, false, false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{KeyDisableProfiles: c.raw})
|
||||
got, ok := p.GetBool(KeyDisableProfiles)
|
||||
assert.Equal(t, c.ok, ok)
|
||||
if c.ok {
|
||||
assert.Equal(t, c.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
_, ok := NewPolicy(nil).GetBool(KeyDisableProfiles)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestPolicy_GetStringSlice(t *testing.T) {
|
||||
t.Run("native string slice", func(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeySplitTunnelApps: []string{"com.a", "com.b"},
|
||||
})
|
||||
got, ok := p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, []string{"com.a", "com.b"}, got)
|
||||
})
|
||||
|
||||
t.Run("any slice of strings", func(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeySplitTunnelApps: []any{"com.a", "com.b"},
|
||||
})
|
||||
got, ok := p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, []string{"com.a", "com.b"}, got)
|
||||
})
|
||||
|
||||
t.Run("single string lifts to one-element slice", func(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeySplitTunnelApps: "com.a",
|
||||
})
|
||||
got, ok := p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, []string{"com.a"}, got)
|
||||
})
|
||||
|
||||
t.Run("mixed any slice rejected", func(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeySplitTunnelApps: []any{"com.a", 1},
|
||||
})
|
||||
_, ok := p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("missing key", func(t *testing.T) {
|
||||
p := NewPolicy(nil)
|
||||
_, ok := p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadPolicy_PlatformStubReturnsEmpty(t *testing.T) {
|
||||
// loadPlatformPolicy is a stub on every OS for Phase 1. LoadPolicy must
|
||||
// degrade gracefully and never return nil.
|
||||
p := LoadPolicy()
|
||||
require.NotNil(t, p)
|
||||
assert.True(t, p.IsEmpty())
|
||||
assert.Empty(t, p.ManagedKeys())
|
||||
}
|
||||
108
client/mdm/policy_windows.go
Normal file
108
client/mdm/policy_windows.go
Normal file
@@ -0,0 +1,108 @@
|
||||
//go:build windows
|
||||
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
)
|
||||
|
||||
// policyRegistryPath is the well-known MDM policy registry key for NetBird.
|
||||
// Admins push values here through Group Policy, Intune ADMX ingestion, an
|
||||
// Intune custom Registry CSP profile, or `reg add` during MSI deployment.
|
||||
// Listed in the project's docs/mdm/netbird.admx schema.
|
||||
const policyRegistryPath = `Software\Policies\NetBird`
|
||||
|
||||
// readRegistryValue reads a single value under policyRegistryPath and,
|
||||
// on success, stores the type-coerced result in out[canonical]. Type
|
||||
// coercion mirrors loadPlatformPolicy's documented mapping:
|
||||
// - REG_SZ / REG_EXPAND_SZ -> string (REG_EXPAND_SZ is expanded by the API)
|
||||
// - REG_DWORD / REG_QWORD -> int64
|
||||
// - REG_MULTI_SZ -> []string
|
||||
//
|
||||
// Unsupported value types and per-value read failures are logged at
|
||||
// warn level and skipped — one malformed value must not block the
|
||||
// surrounding loop. Extracted from loadPlatformPolicy to keep that
|
||||
// function's cognitive complexity in check.
|
||||
func readRegistryValue(k registry.Key, name, canonical string, out map[string]any) {
|
||||
_, valType, err := k.GetValue(name, nil)
|
||||
if err != nil {
|
||||
log.Warnf("MDM stat %s\\%s: %v", policyRegistryPath, name, err)
|
||||
return
|
||||
}
|
||||
switch valType {
|
||||
case registry.SZ, registry.EXPAND_SZ:
|
||||
if v, _, err := k.GetStringValue(name); err == nil {
|
||||
out[canonical] = v
|
||||
} else {
|
||||
log.Warnf("MDM read string %s\\%s: %v", policyRegistryPath, name, err)
|
||||
}
|
||||
case registry.DWORD, registry.QWORD:
|
||||
if v, _, err := k.GetIntegerValue(name); err == nil {
|
||||
// uint64 from the registry API; Policy.GetBool / GetInt
|
||||
// helpers consume int64, so narrow safely.
|
||||
out[canonical] = int64(v)
|
||||
} else {
|
||||
log.Warnf("MDM read int %s\\%s: %v", policyRegistryPath, name, err)
|
||||
}
|
||||
case registry.MULTI_SZ:
|
||||
if v, _, err := k.GetStringsValue(name); err == nil {
|
||||
out[canonical] = v
|
||||
} else {
|
||||
log.Warnf("MDM read multi-string %s\\%s: %v", policyRegistryPath, name, err)
|
||||
}
|
||||
default:
|
||||
log.Warnf("MDM ignoring unsupported registry value type %d at %s\\%s",
|
||||
valType, policyRegistryPath, name)
|
||||
}
|
||||
}
|
||||
|
||||
// loadPlatformPolicy reads the MDM-managed configuration from the
|
||||
// Windows registry under HKLM\Software\Policies\NetBird. Returns:
|
||||
// - (nil, nil) when the key is absent (device not MDM-enrolled for NetBird)
|
||||
// - (map, nil) with N entries when N managed values are set (N may be 0)
|
||||
// - (nil, err) on open / enumerate registry errors
|
||||
//
|
||||
// Per-value type coercion + skip-on-error is delegated to
|
||||
// readRegistryValue. Unknown value names are logged and skipped so a
|
||||
// malformed deployment does not block startup.
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, policyRegistryPath, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
if errors.Is(err, registry.ErrNotExist) {
|
||||
// Not enrolled. Caller treats nil as "no MDM source present".
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("open %s: %w", policyRegistryPath, err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := k.Close(); closeErr != nil {
|
||||
log.Warnf("MDM close registry key %s: %v", policyRegistryPath, closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
names, err := k.ReadValueNames(-1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("enumerate values of %s: %w", policyRegistryPath, err)
|
||||
}
|
||||
|
||||
out := make(map[string]any, len(names))
|
||||
for _, name := range names {
|
||||
// Canonicalize the registry value name against the known MDM key
|
||||
// set so Policy.HasKey lookups (which use the canonical names)
|
||||
// succeed regardless of the casing used by the admin's ADMX or
|
||||
// `reg add` command.
|
||||
canonical, known := canonicalKey[strings.ToLower(name)]
|
||||
if !known {
|
||||
log.Warnf("MDM ignoring unknown registry value %s\\%s", policyRegistryPath, name)
|
||||
continue
|
||||
}
|
||||
readRegistryValue(k, name, canonical, out)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
129
client/mdm/ticker.go
Normal file
129
client/mdm/ticker.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DefaultReloadInterval is the production cadence at which the desktop daemon
|
||||
// re-reads the OS-native MDM policy. Picked to balance responsiveness against
|
||||
// registry/plist I/O overhead. Mobile builds use OS-side notifications
|
||||
// instead, hence anticipating the ticker mechanism entirely.
|
||||
const DefaultReloadInterval = 1 * time.Minute
|
||||
|
||||
// policyLoader is the indirection through which the ticker reads the
|
||||
// OS-native policy, both for the initial observation and on every tick.
|
||||
// Production points it at LoadPolicy; tests in this package override it to
|
||||
// feed a scripted sequence of policies without touching the real OS store.
|
||||
var policyLoader = LoadPolicy
|
||||
|
||||
// Ticker periodically re-reads the OS-native MDM policy via LoadPolicy and
|
||||
// invokes the onChange callback (supplied to Run) whenever the observed
|
||||
// Policy diverges from the last observation (added / removed / changed
|
||||
// keys). Launch with Run from a goroutine; cancel the supplied context
|
||||
// to stop.
|
||||
type Ticker struct {
|
||||
interval time.Duration
|
||||
prev *Policy
|
||||
}
|
||||
|
||||
// NewTicker constructs a Ticker that will re-read the OS-native policy
|
||||
// every reloadInterval once Run is called.
|
||||
// The initial snapshot is populated by calling policyLoader at
|
||||
// construction time so the first tick only fires
|
||||
// onChange when the policy actually changed since boot — without
|
||||
// this baseline the first tick would report every currently-managed
|
||||
// key as "added" and trigger a spurious engine restart.
|
||||
func NewTicker(reloadInterval time.Duration) *Ticker {
|
||||
return &Ticker{
|
||||
interval: reloadInterval,
|
||||
prev: policyLoader(),
|
||||
}
|
||||
}
|
||||
|
||||
// Run blocks until ctx is cancelled, polling the OS-native policy store at
|
||||
// the configured cadence and emitting log lines + onChange callback on
|
||||
// every observed diff. onChange must be non-nil.
|
||||
func (t *Ticker) Run(ctx context.Context, onChange func(prev, curr *Policy) error) {
|
||||
tk := time.NewTicker(t.interval)
|
||||
defer tk.Stop()
|
||||
log.Infof("MDM policy reload ticker started (interval=%s)", t.interval)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("MDM policy reload ticker stopped")
|
||||
return
|
||||
case <-tk.C:
|
||||
curr := policyLoader()
|
||||
if policiesEqual(t.prev, curr) {
|
||||
continue
|
||||
}
|
||||
added, removed, changed := diffPolicies(t.prev, curr)
|
||||
log.Infof("MDM policy changed: added=%v removed=%v changed=%v",
|
||||
added, removed, changed)
|
||||
prev := t.prev
|
||||
if err := onChange(prev, curr); err != nil {
|
||||
log.Errorf("MDM policy change handler failed (retrying in 1 minute): %v", err)
|
||||
continue
|
||||
}
|
||||
t.prev = curr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// policiesEqual reports whether two Policy instances carry the same
|
||||
// managed key set with identical values. Nil and empty policies
|
||||
// compare equal; one-nil/one-non-empty compare not equal; otherwise
|
||||
// the underlying values maps are compared with reflect.DeepEqual.
|
||||
func policiesEqual(a, b *Policy) bool {
|
||||
if a.IsEmpty() && b.IsEmpty() {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
return reflect.DeepEqual(a.values, b.values)
|
||||
}
|
||||
|
||||
// diffPolicies returns the keys added in curr, removed from prev, and
|
||||
// whose values changed between prev and curr. Each slice is sorted
|
||||
// lexicographically for stable log output; value differences are
|
||||
// determined with reflect.DeepEqual.
|
||||
func diffPolicies(prev, curr *Policy) (added, removed, changed []string) {
|
||||
prevKVs := mapOf(prev)
|
||||
currKVs := mapOf(curr)
|
||||
for k := range currKVs {
|
||||
if _, ok := prevKVs[k]; !ok {
|
||||
added = append(added, k)
|
||||
} else if !reflect.DeepEqual(prevKVs[k], currKVs[k]) {
|
||||
changed = append(changed, k)
|
||||
}
|
||||
}
|
||||
for k := range prevKVs {
|
||||
if _, ok := currKVs[k]; !ok {
|
||||
removed = append(removed, k)
|
||||
}
|
||||
}
|
||||
sort.Strings(added)
|
||||
sort.Strings(removed)
|
||||
sort.Strings(changed)
|
||||
return added, removed, changed
|
||||
}
|
||||
|
||||
// mapOf returns a (possibly empty, never nil) copy of the underlying
|
||||
// values map of a Policy so callers outside this package can compare
|
||||
// keys/values across the type boundary. Returns an empty map on nil p.
|
||||
func mapOf(p *Policy) map[string]any {
|
||||
if p == nil {
|
||||
return map[string]any{}
|
||||
}
|
||||
out := make(map[string]any, len(p.values))
|
||||
for k, v := range p.values {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
100
client/mdm/ticker_test.go
Normal file
100
client/mdm/ticker_test.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testReloadInterval for speeding up the ticker cadence under `go test`
|
||||
const testReloadInterval = 1 * time.Second
|
||||
|
||||
// withPolicyLoader overrides the package-level policyLoader for the duration
|
||||
// of the test so the ticker observes a scripted policy instead of the real
|
||||
// OS-native store. The original loader is restored on cleanup.
|
||||
func withPolicyLoader(t *testing.T, fn func() *Policy) {
|
||||
t.Helper()
|
||||
prev := policyLoader
|
||||
policyLoader = fn
|
||||
t.Cleanup(func() { policyLoader = prev })
|
||||
}
|
||||
|
||||
func TestTicker_FiresOnChangeWithDelta(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
current := NewPolicy(nil) // initial observation: empty (no enforcement)
|
||||
withPolicyLoader(t, func() *Policy {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return current
|
||||
})
|
||||
|
||||
type change struct{ prev, curr *Policy }
|
||||
changes := make(chan change, 1)
|
||||
tk := NewTicker(testReloadInterval)
|
||||
require.Equal(t, testReloadInterval, tk.interval)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
tk.Run(ctx, func(prev, curr *Policy) error {
|
||||
select {
|
||||
case changes <- change{prev, curr}:
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
// Stop Run and wait for it to exit before returning, so the policyLoader
|
||||
// restore in t.Cleanup can't race the ticker goroutine still reading it.
|
||||
defer func() { cancel(); <-done }()
|
||||
|
||||
// Flip the OS-observed policy from empty to one managed key. The next
|
||||
// tick must detect the diff and invoke onChange.
|
||||
mu.Lock()
|
||||
current = NewPolicy(map[string]any{KeyManagementURL: "https://mdm.example.com:443"})
|
||||
mu.Unlock()
|
||||
|
||||
select {
|
||||
case c := <-changes:
|
||||
assert.True(t, c.prev.IsEmpty(), "prev should be the initial empty policy")
|
||||
assert.True(t, c.curr.HasKey(KeyManagementURL), "curr should carry the newly-pushed managed key")
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("onChange not invoked within 5s; ticker should fire every 1s under test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTicker_NoCallbackWhenPolicyUnchanged(t *testing.T) {
|
||||
withPolicyLoader(t, func() *Policy {
|
||||
return NewPolicy(map[string]any{KeyBlockInbound: true})
|
||||
})
|
||||
|
||||
fired := make(chan struct{}, 1)
|
||||
tk := NewTicker(testReloadInterval)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
tk.Run(ctx, func(_, _ *Policy) error {
|
||||
select {
|
||||
case fired <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
defer func() { cancel(); <-done }()
|
||||
|
||||
// Over ~2 ticks at the 1s test cadence the policy never changes, so the
|
||||
// diff guard must suppress the callback entirely.
|
||||
select {
|
||||
case <-fired:
|
||||
t.Fatal("onChange fired despite an unchanged policy")
|
||||
case <-time.After(2500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
@@ -1191,8 +1191,14 @@ type GetConfigResponse struct {
|
||||
DisableSSHAuth bool `protobuf:"varint,25,opt,name=disableSSHAuth,proto3" json:"disableSSHAuth,omitempty"`
|
||||
SshJWTCacheTTL int32 `protobuf:"varint,26,opt,name=sshJWTCacheTTL,proto3" json:"sshJWTCacheTTL,omitempty"`
|
||||
DisableIpv6 bool `protobuf:"varint,27,opt,name=disable_ipv6,json=disableIpv6,proto3" json:"disable_ipv6,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
// mDMManagedFields lists the names of configuration keys whose value is
|
||||
// currently enforced by an MDM policy. Names match mdm.Key* constants
|
||||
// (e.g. "managementURL", "disableClientRoutes"). UI/CLI clients should
|
||||
// render the corresponding inputs as read-only and display a "managed
|
||||
// by MDM" indicator.
|
||||
MDMManagedFields []string `protobuf:"bytes,28,rep,name=mDMManagedFields,proto3" json:"mDMManagedFields,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *GetConfigResponse) Reset() {
|
||||
@@ -1414,6 +1420,13 @@ func (x *GetConfigResponse) GetDisableIpv6() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *GetConfigResponse) GetMDMManagedFields() []string {
|
||||
if x != nil {
|
||||
return x.MDMManagedFields
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PeerState contains the latest state of a peer
|
||||
type PeerState struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
@@ -4961,6 +4974,55 @@ func (x *GetFeaturesResponse) GetDisableNetworks() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// MDMManagedFieldsViolation is attached as a gRPC error detail on a
|
||||
// FailedPrecondition status returned from SetConfig (and similar mutating
|
||||
// RPCs) when the caller tries to modify one or more MDM-enforced fields.
|
||||
// The fields list contains the offending key names; the entire request is
|
||||
// rejected (no partial apply).
|
||||
type MDMManagedFieldsViolation struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Fields []string `protobuf:"bytes,1,rep,name=fields,proto3" json:"fields,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *MDMManagedFieldsViolation) Reset() {
|
||||
*x = MDMManagedFieldsViolation{}
|
||||
mi := &file_daemon_proto_msgTypes[71]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *MDMManagedFieldsViolation) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*MDMManagedFieldsViolation) ProtoMessage() {}
|
||||
|
||||
func (x *MDMManagedFieldsViolation) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[71]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use MDMManagedFieldsViolation.ProtoReflect.Descriptor instead.
|
||||
func (*MDMManagedFieldsViolation) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{71}
|
||||
}
|
||||
|
||||
func (x *MDMManagedFieldsViolation) GetFields() []string {
|
||||
if x != nil {
|
||||
return x.Fields
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type TriggerUpdateRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
@@ -4969,7 +5031,7 @@ type TriggerUpdateRequest struct {
|
||||
|
||||
func (x *TriggerUpdateRequest) Reset() {
|
||||
*x = TriggerUpdateRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[71]
|
||||
mi := &file_daemon_proto_msgTypes[72]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -4981,7 +5043,7 @@ func (x *TriggerUpdateRequest) String() string {
|
||||
func (*TriggerUpdateRequest) ProtoMessage() {}
|
||||
|
||||
func (x *TriggerUpdateRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[71]
|
||||
mi := &file_daemon_proto_msgTypes[72]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -4994,7 +5056,7 @@ func (x *TriggerUpdateRequest) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use TriggerUpdateRequest.ProtoReflect.Descriptor instead.
|
||||
func (*TriggerUpdateRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{71}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{72}
|
||||
}
|
||||
|
||||
type TriggerUpdateResponse struct {
|
||||
@@ -5007,7 +5069,7 @@ type TriggerUpdateResponse struct {
|
||||
|
||||
func (x *TriggerUpdateResponse) Reset() {
|
||||
*x = TriggerUpdateResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[72]
|
||||
mi := &file_daemon_proto_msgTypes[73]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5019,7 +5081,7 @@ func (x *TriggerUpdateResponse) String() string {
|
||||
func (*TriggerUpdateResponse) ProtoMessage() {}
|
||||
|
||||
func (x *TriggerUpdateResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[72]
|
||||
mi := &file_daemon_proto_msgTypes[73]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5032,7 +5094,7 @@ func (x *TriggerUpdateResponse) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use TriggerUpdateResponse.ProtoReflect.Descriptor instead.
|
||||
func (*TriggerUpdateResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{72}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{73}
|
||||
}
|
||||
|
||||
func (x *TriggerUpdateResponse) GetSuccess() bool {
|
||||
@@ -5060,7 +5122,7 @@ type GetPeerSSHHostKeyRequest struct {
|
||||
|
||||
func (x *GetPeerSSHHostKeyRequest) Reset() {
|
||||
*x = GetPeerSSHHostKeyRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[73]
|
||||
mi := &file_daemon_proto_msgTypes[74]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5072,7 +5134,7 @@ func (x *GetPeerSSHHostKeyRequest) String() string {
|
||||
func (*GetPeerSSHHostKeyRequest) ProtoMessage() {}
|
||||
|
||||
func (x *GetPeerSSHHostKeyRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[73]
|
||||
mi := &file_daemon_proto_msgTypes[74]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5085,7 +5147,7 @@ func (x *GetPeerSSHHostKeyRequest) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use GetPeerSSHHostKeyRequest.ProtoReflect.Descriptor instead.
|
||||
func (*GetPeerSSHHostKeyRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{73}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{74}
|
||||
}
|
||||
|
||||
func (x *GetPeerSSHHostKeyRequest) GetPeerAddress() string {
|
||||
@@ -5112,7 +5174,7 @@ type GetPeerSSHHostKeyResponse struct {
|
||||
|
||||
func (x *GetPeerSSHHostKeyResponse) Reset() {
|
||||
*x = GetPeerSSHHostKeyResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[74]
|
||||
mi := &file_daemon_proto_msgTypes[75]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5124,7 +5186,7 @@ func (x *GetPeerSSHHostKeyResponse) String() string {
|
||||
func (*GetPeerSSHHostKeyResponse) ProtoMessage() {}
|
||||
|
||||
func (x *GetPeerSSHHostKeyResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[74]
|
||||
mi := &file_daemon_proto_msgTypes[75]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5137,7 +5199,7 @@ func (x *GetPeerSSHHostKeyResponse) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use GetPeerSSHHostKeyResponse.ProtoReflect.Descriptor instead.
|
||||
func (*GetPeerSSHHostKeyResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{74}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{75}
|
||||
}
|
||||
|
||||
func (x *GetPeerSSHHostKeyResponse) GetSshHostKey() []byte {
|
||||
@@ -5179,7 +5241,7 @@ type RequestJWTAuthRequest struct {
|
||||
|
||||
func (x *RequestJWTAuthRequest) Reset() {
|
||||
*x = RequestJWTAuthRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[75]
|
||||
mi := &file_daemon_proto_msgTypes[76]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5191,7 +5253,7 @@ func (x *RequestJWTAuthRequest) String() string {
|
||||
func (*RequestJWTAuthRequest) ProtoMessage() {}
|
||||
|
||||
func (x *RequestJWTAuthRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[75]
|
||||
mi := &file_daemon_proto_msgTypes[76]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5204,7 +5266,7 @@ func (x *RequestJWTAuthRequest) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use RequestJWTAuthRequest.ProtoReflect.Descriptor instead.
|
||||
func (*RequestJWTAuthRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{75}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{76}
|
||||
}
|
||||
|
||||
func (x *RequestJWTAuthRequest) GetHint() string {
|
||||
@@ -5237,7 +5299,7 @@ type RequestJWTAuthResponse struct {
|
||||
|
||||
func (x *RequestJWTAuthResponse) Reset() {
|
||||
*x = RequestJWTAuthResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[76]
|
||||
mi := &file_daemon_proto_msgTypes[77]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5249,7 +5311,7 @@ func (x *RequestJWTAuthResponse) String() string {
|
||||
func (*RequestJWTAuthResponse) ProtoMessage() {}
|
||||
|
||||
func (x *RequestJWTAuthResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[76]
|
||||
mi := &file_daemon_proto_msgTypes[77]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5262,7 +5324,7 @@ func (x *RequestJWTAuthResponse) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use RequestJWTAuthResponse.ProtoReflect.Descriptor instead.
|
||||
func (*RequestJWTAuthResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{76}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{77}
|
||||
}
|
||||
|
||||
func (x *RequestJWTAuthResponse) GetVerificationURI() string {
|
||||
@@ -5327,7 +5389,7 @@ type WaitJWTTokenRequest struct {
|
||||
|
||||
func (x *WaitJWTTokenRequest) Reset() {
|
||||
*x = WaitJWTTokenRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[77]
|
||||
mi := &file_daemon_proto_msgTypes[78]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5339,7 +5401,7 @@ func (x *WaitJWTTokenRequest) String() string {
|
||||
func (*WaitJWTTokenRequest) ProtoMessage() {}
|
||||
|
||||
func (x *WaitJWTTokenRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[77]
|
||||
mi := &file_daemon_proto_msgTypes[78]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5352,7 +5414,7 @@ func (x *WaitJWTTokenRequest) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use WaitJWTTokenRequest.ProtoReflect.Descriptor instead.
|
||||
func (*WaitJWTTokenRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{77}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{78}
|
||||
}
|
||||
|
||||
func (x *WaitJWTTokenRequest) GetDeviceCode() string {
|
||||
@@ -5384,7 +5446,7 @@ type WaitJWTTokenResponse struct {
|
||||
|
||||
func (x *WaitJWTTokenResponse) Reset() {
|
||||
*x = WaitJWTTokenResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[78]
|
||||
mi := &file_daemon_proto_msgTypes[79]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5396,7 +5458,7 @@ func (x *WaitJWTTokenResponse) String() string {
|
||||
func (*WaitJWTTokenResponse) ProtoMessage() {}
|
||||
|
||||
func (x *WaitJWTTokenResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[78]
|
||||
mi := &file_daemon_proto_msgTypes[79]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5409,7 +5471,7 @@ func (x *WaitJWTTokenResponse) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use WaitJWTTokenResponse.ProtoReflect.Descriptor instead.
|
||||
func (*WaitJWTTokenResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{78}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{79}
|
||||
}
|
||||
|
||||
func (x *WaitJWTTokenResponse) GetToken() string {
|
||||
@@ -5442,7 +5504,7 @@ type StartCPUProfileRequest struct {
|
||||
|
||||
func (x *StartCPUProfileRequest) Reset() {
|
||||
*x = StartCPUProfileRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[79]
|
||||
mi := &file_daemon_proto_msgTypes[80]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5454,7 +5516,7 @@ func (x *StartCPUProfileRequest) String() string {
|
||||
func (*StartCPUProfileRequest) ProtoMessage() {}
|
||||
|
||||
func (x *StartCPUProfileRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[79]
|
||||
mi := &file_daemon_proto_msgTypes[80]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5467,7 +5529,7 @@ func (x *StartCPUProfileRequest) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use StartCPUProfileRequest.ProtoReflect.Descriptor instead.
|
||||
func (*StartCPUProfileRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{79}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{80}
|
||||
}
|
||||
|
||||
// StartCPUProfileResponse confirms CPU profiling has started
|
||||
@@ -5479,7 +5541,7 @@ type StartCPUProfileResponse struct {
|
||||
|
||||
func (x *StartCPUProfileResponse) Reset() {
|
||||
*x = StartCPUProfileResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[80]
|
||||
mi := &file_daemon_proto_msgTypes[81]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5491,7 +5553,7 @@ func (x *StartCPUProfileResponse) String() string {
|
||||
func (*StartCPUProfileResponse) ProtoMessage() {}
|
||||
|
||||
func (x *StartCPUProfileResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[80]
|
||||
mi := &file_daemon_proto_msgTypes[81]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5504,7 +5566,7 @@ func (x *StartCPUProfileResponse) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use StartCPUProfileResponse.ProtoReflect.Descriptor instead.
|
||||
func (*StartCPUProfileResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{80}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{81}
|
||||
}
|
||||
|
||||
// StopCPUProfileRequest for stopping CPU profiling
|
||||
@@ -5516,7 +5578,7 @@ type StopCPUProfileRequest struct {
|
||||
|
||||
func (x *StopCPUProfileRequest) Reset() {
|
||||
*x = StopCPUProfileRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[81]
|
||||
mi := &file_daemon_proto_msgTypes[82]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5528,7 +5590,7 @@ func (x *StopCPUProfileRequest) String() string {
|
||||
func (*StopCPUProfileRequest) ProtoMessage() {}
|
||||
|
||||
func (x *StopCPUProfileRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[81]
|
||||
mi := &file_daemon_proto_msgTypes[82]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5541,7 +5603,7 @@ func (x *StopCPUProfileRequest) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use StopCPUProfileRequest.ProtoReflect.Descriptor instead.
|
||||
func (*StopCPUProfileRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{81}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{82}
|
||||
}
|
||||
|
||||
// StopCPUProfileResponse confirms CPU profiling has stopped
|
||||
@@ -5553,7 +5615,7 @@ type StopCPUProfileResponse struct {
|
||||
|
||||
func (x *StopCPUProfileResponse) Reset() {
|
||||
*x = StopCPUProfileResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[82]
|
||||
mi := &file_daemon_proto_msgTypes[83]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5565,7 +5627,7 @@ func (x *StopCPUProfileResponse) String() string {
|
||||
func (*StopCPUProfileResponse) ProtoMessage() {}
|
||||
|
||||
func (x *StopCPUProfileResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[82]
|
||||
mi := &file_daemon_proto_msgTypes[83]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5578,7 +5640,7 @@ func (x *StopCPUProfileResponse) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use StopCPUProfileResponse.ProtoReflect.Descriptor instead.
|
||||
func (*StopCPUProfileResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{82}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{83}
|
||||
}
|
||||
|
||||
type InstallerResultRequest struct {
|
||||
@@ -5589,7 +5651,7 @@ type InstallerResultRequest struct {
|
||||
|
||||
func (x *InstallerResultRequest) Reset() {
|
||||
*x = InstallerResultRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[83]
|
||||
mi := &file_daemon_proto_msgTypes[84]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5601,7 +5663,7 @@ func (x *InstallerResultRequest) String() string {
|
||||
func (*InstallerResultRequest) ProtoMessage() {}
|
||||
|
||||
func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[83]
|
||||
mi := &file_daemon_proto_msgTypes[84]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5614,7 +5676,7 @@ func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use InstallerResultRequest.ProtoReflect.Descriptor instead.
|
||||
func (*InstallerResultRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{83}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{84}
|
||||
}
|
||||
|
||||
type InstallerResultResponse struct {
|
||||
@@ -5627,7 +5689,7 @@ type InstallerResultResponse struct {
|
||||
|
||||
func (x *InstallerResultResponse) Reset() {
|
||||
*x = InstallerResultResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[84]
|
||||
mi := &file_daemon_proto_msgTypes[85]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5639,7 +5701,7 @@ func (x *InstallerResultResponse) String() string {
|
||||
func (*InstallerResultResponse) ProtoMessage() {}
|
||||
|
||||
func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[84]
|
||||
mi := &file_daemon_proto_msgTypes[85]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5652,7 +5714,7 @@ func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use InstallerResultResponse.ProtoReflect.Descriptor instead.
|
||||
func (*InstallerResultResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{84}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{85}
|
||||
}
|
||||
|
||||
func (x *InstallerResultResponse) GetSuccess() bool {
|
||||
@@ -5685,7 +5747,7 @@ type ExposeServiceRequest struct {
|
||||
|
||||
func (x *ExposeServiceRequest) Reset() {
|
||||
*x = ExposeServiceRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[85]
|
||||
mi := &file_daemon_proto_msgTypes[86]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5697,7 +5759,7 @@ func (x *ExposeServiceRequest) String() string {
|
||||
func (*ExposeServiceRequest) ProtoMessage() {}
|
||||
|
||||
func (x *ExposeServiceRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[85]
|
||||
mi := &file_daemon_proto_msgTypes[86]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5710,7 +5772,7 @@ func (x *ExposeServiceRequest) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use ExposeServiceRequest.ProtoReflect.Descriptor instead.
|
||||
func (*ExposeServiceRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{85}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{86}
|
||||
}
|
||||
|
||||
func (x *ExposeServiceRequest) GetPort() uint32 {
|
||||
@@ -5781,7 +5843,7 @@ type ExposeServiceEvent struct {
|
||||
|
||||
func (x *ExposeServiceEvent) Reset() {
|
||||
*x = ExposeServiceEvent{}
|
||||
mi := &file_daemon_proto_msgTypes[86]
|
||||
mi := &file_daemon_proto_msgTypes[87]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5793,7 +5855,7 @@ func (x *ExposeServiceEvent) String() string {
|
||||
func (*ExposeServiceEvent) ProtoMessage() {}
|
||||
|
||||
func (x *ExposeServiceEvent) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[86]
|
||||
mi := &file_daemon_proto_msgTypes[87]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5806,7 +5868,7 @@ func (x *ExposeServiceEvent) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use ExposeServiceEvent.ProtoReflect.Descriptor instead.
|
||||
func (*ExposeServiceEvent) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{86}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{87}
|
||||
}
|
||||
|
||||
func (x *ExposeServiceEvent) GetEvent() isExposeServiceEvent_Event {
|
||||
@@ -5847,7 +5909,7 @@ type ExposeServiceReady struct {
|
||||
|
||||
func (x *ExposeServiceReady) Reset() {
|
||||
*x = ExposeServiceReady{}
|
||||
mi := &file_daemon_proto_msgTypes[87]
|
||||
mi := &file_daemon_proto_msgTypes[88]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5859,7 +5921,7 @@ func (x *ExposeServiceReady) String() string {
|
||||
func (*ExposeServiceReady) ProtoMessage() {}
|
||||
|
||||
func (x *ExposeServiceReady) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[87]
|
||||
mi := &file_daemon_proto_msgTypes[88]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5872,7 +5934,7 @@ func (x *ExposeServiceReady) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use ExposeServiceReady.ProtoReflect.Descriptor instead.
|
||||
func (*ExposeServiceReady) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{87}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{88}
|
||||
}
|
||||
|
||||
func (x *ExposeServiceReady) GetServiceName() string {
|
||||
@@ -5917,7 +5979,7 @@ type StartCaptureRequest struct {
|
||||
|
||||
func (x *StartCaptureRequest) Reset() {
|
||||
*x = StartCaptureRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[88]
|
||||
mi := &file_daemon_proto_msgTypes[89]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5929,7 +5991,7 @@ func (x *StartCaptureRequest) String() string {
|
||||
func (*StartCaptureRequest) ProtoMessage() {}
|
||||
|
||||
func (x *StartCaptureRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[88]
|
||||
mi := &file_daemon_proto_msgTypes[89]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5942,7 +6004,7 @@ func (x *StartCaptureRequest) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use StartCaptureRequest.ProtoReflect.Descriptor instead.
|
||||
func (*StartCaptureRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{88}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{89}
|
||||
}
|
||||
|
||||
func (x *StartCaptureRequest) GetTextOutput() bool {
|
||||
@@ -5996,7 +6058,7 @@ type CapturePacket struct {
|
||||
|
||||
func (x *CapturePacket) Reset() {
|
||||
*x = CapturePacket{}
|
||||
mi := &file_daemon_proto_msgTypes[89]
|
||||
mi := &file_daemon_proto_msgTypes[90]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -6008,7 +6070,7 @@ func (x *CapturePacket) String() string {
|
||||
func (*CapturePacket) ProtoMessage() {}
|
||||
|
||||
func (x *CapturePacket) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[89]
|
||||
mi := &file_daemon_proto_msgTypes[90]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -6021,7 +6083,7 @@ func (x *CapturePacket) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use CapturePacket.ProtoReflect.Descriptor instead.
|
||||
func (*CapturePacket) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{89}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{90}
|
||||
}
|
||||
|
||||
func (x *CapturePacket) GetData() []byte {
|
||||
@@ -6042,7 +6104,7 @@ type StartBundleCaptureRequest struct {
|
||||
|
||||
func (x *StartBundleCaptureRequest) Reset() {
|
||||
*x = StartBundleCaptureRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[90]
|
||||
mi := &file_daemon_proto_msgTypes[91]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -6054,7 +6116,7 @@ func (x *StartBundleCaptureRequest) String() string {
|
||||
func (*StartBundleCaptureRequest) ProtoMessage() {}
|
||||
|
||||
func (x *StartBundleCaptureRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[90]
|
||||
mi := &file_daemon_proto_msgTypes[91]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -6067,7 +6129,7 @@ func (x *StartBundleCaptureRequest) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use StartBundleCaptureRequest.ProtoReflect.Descriptor instead.
|
||||
func (*StartBundleCaptureRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{90}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{91}
|
||||
}
|
||||
|
||||
func (x *StartBundleCaptureRequest) GetTimeout() *durationpb.Duration {
|
||||
@@ -6085,7 +6147,7 @@ type StartBundleCaptureResponse struct {
|
||||
|
||||
func (x *StartBundleCaptureResponse) Reset() {
|
||||
*x = StartBundleCaptureResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[91]
|
||||
mi := &file_daemon_proto_msgTypes[92]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -6097,7 +6159,7 @@ func (x *StartBundleCaptureResponse) String() string {
|
||||
func (*StartBundleCaptureResponse) ProtoMessage() {}
|
||||
|
||||
func (x *StartBundleCaptureResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[91]
|
||||
mi := &file_daemon_proto_msgTypes[92]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -6110,7 +6172,7 @@ func (x *StartBundleCaptureResponse) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use StartBundleCaptureResponse.ProtoReflect.Descriptor instead.
|
||||
func (*StartBundleCaptureResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{91}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{92}
|
||||
}
|
||||
|
||||
type StopBundleCaptureRequest struct {
|
||||
@@ -6121,7 +6183,7 @@ type StopBundleCaptureRequest struct {
|
||||
|
||||
func (x *StopBundleCaptureRequest) Reset() {
|
||||
*x = StopBundleCaptureRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[92]
|
||||
mi := &file_daemon_proto_msgTypes[93]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -6133,7 +6195,7 @@ func (x *StopBundleCaptureRequest) String() string {
|
||||
func (*StopBundleCaptureRequest) ProtoMessage() {}
|
||||
|
||||
func (x *StopBundleCaptureRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[92]
|
||||
mi := &file_daemon_proto_msgTypes[93]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -6146,7 +6208,7 @@ func (x *StopBundleCaptureRequest) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use StopBundleCaptureRequest.ProtoReflect.Descriptor instead.
|
||||
func (*StopBundleCaptureRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{92}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{93}
|
||||
}
|
||||
|
||||
type StopBundleCaptureResponse struct {
|
||||
@@ -6157,7 +6219,7 @@ type StopBundleCaptureResponse struct {
|
||||
|
||||
func (x *StopBundleCaptureResponse) Reset() {
|
||||
*x = StopBundleCaptureResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[93]
|
||||
mi := &file_daemon_proto_msgTypes[94]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -6169,7 +6231,7 @@ func (x *StopBundleCaptureResponse) String() string {
|
||||
func (*StopBundleCaptureResponse) ProtoMessage() {}
|
||||
|
||||
func (x *StopBundleCaptureResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[93]
|
||||
mi := &file_daemon_proto_msgTypes[94]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -6182,7 +6244,7 @@ func (x *StopBundleCaptureResponse) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use StopBundleCaptureResponse.ProtoReflect.Descriptor instead.
|
||||
func (*StopBundleCaptureResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{93}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{94}
|
||||
}
|
||||
|
||||
type PortInfo_Range struct {
|
||||
@@ -6195,7 +6257,7 @@ type PortInfo_Range struct {
|
||||
|
||||
func (x *PortInfo_Range) Reset() {
|
||||
*x = PortInfo_Range{}
|
||||
mi := &file_daemon_proto_msgTypes[95]
|
||||
mi := &file_daemon_proto_msgTypes[96]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -6207,7 +6269,7 @@ func (x *PortInfo_Range) String() string {
|
||||
func (*PortInfo_Range) ProtoMessage() {}
|
||||
|
||||
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[95]
|
||||
mi := &file_daemon_proto_msgTypes[96]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -6348,7 +6410,7 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\fDownResponse\"P\n" +
|
||||
"\x10GetConfigRequest\x12 \n" +
|
||||
"\vprofileName\x18\x01 \x01(\tR\vprofileName\x12\x1a\n" +
|
||||
"\busername\x18\x02 \x01(\tR\busername\"\xfe\b\n" +
|
||||
"\busername\x18\x02 \x01(\tR\busername\"\xaa\t\n" +
|
||||
"\x11GetConfigResponse\x12$\n" +
|
||||
"\rmanagementUrl\x18\x01 \x01(\tR\rmanagementUrl\x12\x1e\n" +
|
||||
"\n" +
|
||||
@@ -6380,7 +6442,8 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\x1denableSSHRemotePortForwarding\x18\x17 \x01(\bR\x1denableSSHRemotePortForwarding\x12&\n" +
|
||||
"\x0edisableSSHAuth\x18\x19 \x01(\bR\x0edisableSSHAuth\x12&\n" +
|
||||
"\x0esshJWTCacheTTL\x18\x1a \x01(\x05R\x0esshJWTCacheTTL\x12!\n" +
|
||||
"\fdisable_ipv6\x18\x1b \x01(\bR\vdisableIpv6\"\x92\x06\n" +
|
||||
"\fdisable_ipv6\x18\x1b \x01(\bR\vdisableIpv6\x12*\n" +
|
||||
"\x10mDMManagedFields\x18\x1c \x03(\tR\x10mDMManagedFields\"\x92\x06\n" +
|
||||
"\tPeerState\x12\x0e\n" +
|
||||
"\x02IP\x18\x01 \x01(\tR\x02IP\x12\x16\n" +
|
||||
"\x06pubKey\x18\x02 \x01(\tR\x06pubKey\x12\x1e\n" +
|
||||
@@ -6695,7 +6758,9 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\x13GetFeaturesResponse\x12)\n" +
|
||||
"\x10disable_profiles\x18\x01 \x01(\bR\x0fdisableProfiles\x126\n" +
|
||||
"\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings\x12)\n" +
|
||||
"\x10disable_networks\x18\x03 \x01(\bR\x0fdisableNetworks\"\x16\n" +
|
||||
"\x10disable_networks\x18\x03 \x01(\bR\x0fdisableNetworks\"3\n" +
|
||||
"\x19MDMManagedFieldsViolation\x12\x16\n" +
|
||||
"\x06fields\x18\x01 \x03(\tR\x06fields\"\x16\n" +
|
||||
"\x14TriggerUpdateRequest\"M\n" +
|
||||
"\x15TriggerUpdateResponse\x12\x18\n" +
|
||||
"\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" +
|
||||
@@ -6851,7 +6916,7 @@ func file_daemon_proto_rawDescGZIP() []byte {
|
||||
}
|
||||
|
||||
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4)
|
||||
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 97)
|
||||
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 98)
|
||||
var file_daemon_proto_goTypes = []any{
|
||||
(LogLevel)(0), // 0: daemon.LogLevel
|
||||
(ExposeProtocol)(0), // 1: daemon.ExposeProtocol
|
||||
@@ -6928,41 +6993,42 @@ var file_daemon_proto_goTypes = []any{
|
||||
(*LogoutResponse)(nil), // 72: daemon.LogoutResponse
|
||||
(*GetFeaturesRequest)(nil), // 73: daemon.GetFeaturesRequest
|
||||
(*GetFeaturesResponse)(nil), // 74: daemon.GetFeaturesResponse
|
||||
(*TriggerUpdateRequest)(nil), // 75: daemon.TriggerUpdateRequest
|
||||
(*TriggerUpdateResponse)(nil), // 76: daemon.TriggerUpdateResponse
|
||||
(*GetPeerSSHHostKeyRequest)(nil), // 77: daemon.GetPeerSSHHostKeyRequest
|
||||
(*GetPeerSSHHostKeyResponse)(nil), // 78: daemon.GetPeerSSHHostKeyResponse
|
||||
(*RequestJWTAuthRequest)(nil), // 79: daemon.RequestJWTAuthRequest
|
||||
(*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse
|
||||
(*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest
|
||||
(*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse
|
||||
(*StartCPUProfileRequest)(nil), // 83: daemon.StartCPUProfileRequest
|
||||
(*StartCPUProfileResponse)(nil), // 84: daemon.StartCPUProfileResponse
|
||||
(*StopCPUProfileRequest)(nil), // 85: daemon.StopCPUProfileRequest
|
||||
(*StopCPUProfileResponse)(nil), // 86: daemon.StopCPUProfileResponse
|
||||
(*InstallerResultRequest)(nil), // 87: daemon.InstallerResultRequest
|
||||
(*InstallerResultResponse)(nil), // 88: daemon.InstallerResultResponse
|
||||
(*ExposeServiceRequest)(nil), // 89: daemon.ExposeServiceRequest
|
||||
(*ExposeServiceEvent)(nil), // 90: daemon.ExposeServiceEvent
|
||||
(*ExposeServiceReady)(nil), // 91: daemon.ExposeServiceReady
|
||||
(*StartCaptureRequest)(nil), // 92: daemon.StartCaptureRequest
|
||||
(*CapturePacket)(nil), // 93: daemon.CapturePacket
|
||||
(*StartBundleCaptureRequest)(nil), // 94: daemon.StartBundleCaptureRequest
|
||||
(*StartBundleCaptureResponse)(nil), // 95: daemon.StartBundleCaptureResponse
|
||||
(*StopBundleCaptureRequest)(nil), // 96: daemon.StopBundleCaptureRequest
|
||||
(*StopBundleCaptureResponse)(nil), // 97: daemon.StopBundleCaptureResponse
|
||||
nil, // 98: daemon.Network.ResolvedIPsEntry
|
||||
(*PortInfo_Range)(nil), // 99: daemon.PortInfo.Range
|
||||
nil, // 100: daemon.SystemEvent.MetadataEntry
|
||||
(*durationpb.Duration)(nil), // 101: google.protobuf.Duration
|
||||
(*timestamppb.Timestamp)(nil), // 102: google.protobuf.Timestamp
|
||||
(*MDMManagedFieldsViolation)(nil), // 75: daemon.MDMManagedFieldsViolation
|
||||
(*TriggerUpdateRequest)(nil), // 76: daemon.TriggerUpdateRequest
|
||||
(*TriggerUpdateResponse)(nil), // 77: daemon.TriggerUpdateResponse
|
||||
(*GetPeerSSHHostKeyRequest)(nil), // 78: daemon.GetPeerSSHHostKeyRequest
|
||||
(*GetPeerSSHHostKeyResponse)(nil), // 79: daemon.GetPeerSSHHostKeyResponse
|
||||
(*RequestJWTAuthRequest)(nil), // 80: daemon.RequestJWTAuthRequest
|
||||
(*RequestJWTAuthResponse)(nil), // 81: daemon.RequestJWTAuthResponse
|
||||
(*WaitJWTTokenRequest)(nil), // 82: daemon.WaitJWTTokenRequest
|
||||
(*WaitJWTTokenResponse)(nil), // 83: daemon.WaitJWTTokenResponse
|
||||
(*StartCPUProfileRequest)(nil), // 84: daemon.StartCPUProfileRequest
|
||||
(*StartCPUProfileResponse)(nil), // 85: daemon.StartCPUProfileResponse
|
||||
(*StopCPUProfileRequest)(nil), // 86: daemon.StopCPUProfileRequest
|
||||
(*StopCPUProfileResponse)(nil), // 87: daemon.StopCPUProfileResponse
|
||||
(*InstallerResultRequest)(nil), // 88: daemon.InstallerResultRequest
|
||||
(*InstallerResultResponse)(nil), // 89: daemon.InstallerResultResponse
|
||||
(*ExposeServiceRequest)(nil), // 90: daemon.ExposeServiceRequest
|
||||
(*ExposeServiceEvent)(nil), // 91: daemon.ExposeServiceEvent
|
||||
(*ExposeServiceReady)(nil), // 92: daemon.ExposeServiceReady
|
||||
(*StartCaptureRequest)(nil), // 93: daemon.StartCaptureRequest
|
||||
(*CapturePacket)(nil), // 94: daemon.CapturePacket
|
||||
(*StartBundleCaptureRequest)(nil), // 95: daemon.StartBundleCaptureRequest
|
||||
(*StartBundleCaptureResponse)(nil), // 96: daemon.StartBundleCaptureResponse
|
||||
(*StopBundleCaptureRequest)(nil), // 97: daemon.StopBundleCaptureRequest
|
||||
(*StopBundleCaptureResponse)(nil), // 98: daemon.StopBundleCaptureResponse
|
||||
nil, // 99: daemon.Network.ResolvedIPsEntry
|
||||
(*PortInfo_Range)(nil), // 100: daemon.PortInfo.Range
|
||||
nil, // 101: daemon.SystemEvent.MetadataEntry
|
||||
(*durationpb.Duration)(nil), // 102: google.protobuf.Duration
|
||||
(*timestamppb.Timestamp)(nil), // 103: google.protobuf.Timestamp
|
||||
}
|
||||
var file_daemon_proto_depIdxs = []int32{
|
||||
101, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
102, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
25, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
|
||||
102, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
||||
102, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
||||
101, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
||||
103, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
||||
103, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
||||
102, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
||||
23, // 5: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
|
||||
20, // 6: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
|
||||
19, // 7: daemon.FullStatus.signalState:type_name -> daemon.SignalState
|
||||
@@ -6973,8 +7039,8 @@ var file_daemon_proto_depIdxs = []int32{
|
||||
55, // 12: daemon.FullStatus.events:type_name -> daemon.SystemEvent
|
||||
24, // 13: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
|
||||
31, // 14: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
|
||||
98, // 15: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
|
||||
99, // 16: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
|
||||
99, // 15: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
|
||||
100, // 16: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
|
||||
32, // 17: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
|
||||
32, // 18: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
|
||||
33, // 19: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
|
||||
@@ -6985,15 +7051,15 @@ var file_daemon_proto_depIdxs = []int32{
|
||||
52, // 24: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
|
||||
2, // 25: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
|
||||
3, // 26: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
|
||||
102, // 27: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
|
||||
100, // 28: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
|
||||
103, // 27: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
|
||||
101, // 28: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
|
||||
55, // 29: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
|
||||
101, // 30: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
102, // 30: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
68, // 31: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
|
||||
1, // 32: daemon.ExposeServiceRequest.protocol:type_name -> daemon.ExposeProtocol
|
||||
91, // 33: daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady
|
||||
101, // 34: daemon.StartCaptureRequest.duration:type_name -> google.protobuf.Duration
|
||||
101, // 35: daemon.StartBundleCaptureRequest.timeout:type_name -> google.protobuf.Duration
|
||||
92, // 33: daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady
|
||||
102, // 34: daemon.StartCaptureRequest.duration:type_name -> google.protobuf.Duration
|
||||
102, // 35: daemon.StartBundleCaptureRequest.timeout:type_name -> google.protobuf.Duration
|
||||
30, // 36: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
|
||||
5, // 37: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
|
||||
7, // 38: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
|
||||
@@ -7013,9 +7079,9 @@ var file_daemon_proto_depIdxs = []int32{
|
||||
46, // 52: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
|
||||
48, // 53: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest
|
||||
51, // 54: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest
|
||||
92, // 55: daemon.DaemonService.StartCapture:input_type -> daemon.StartCaptureRequest
|
||||
94, // 56: daemon.DaemonService.StartBundleCapture:input_type -> daemon.StartBundleCaptureRequest
|
||||
96, // 57: daemon.DaemonService.StopBundleCapture:input_type -> daemon.StopBundleCaptureRequest
|
||||
93, // 55: daemon.DaemonService.StartCapture:input_type -> daemon.StartCaptureRequest
|
||||
95, // 56: daemon.DaemonService.StartBundleCapture:input_type -> daemon.StartBundleCaptureRequest
|
||||
97, // 57: daemon.DaemonService.StopBundleCapture:input_type -> daemon.StopBundleCaptureRequest
|
||||
54, // 58: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest
|
||||
56, // 59: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest
|
||||
58, // 60: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest
|
||||
@@ -7026,14 +7092,14 @@ var file_daemon_proto_depIdxs = []int32{
|
||||
69, // 65: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest
|
||||
71, // 66: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
|
||||
73, // 67: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
|
||||
75, // 68: daemon.DaemonService.TriggerUpdate:input_type -> daemon.TriggerUpdateRequest
|
||||
77, // 69: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
|
||||
79, // 70: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
|
||||
81, // 71: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
|
||||
83, // 72: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest
|
||||
85, // 73: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest
|
||||
87, // 74: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
|
||||
89, // 75: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest
|
||||
76, // 68: daemon.DaemonService.TriggerUpdate:input_type -> daemon.TriggerUpdateRequest
|
||||
78, // 69: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
|
||||
80, // 70: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
|
||||
82, // 71: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
|
||||
84, // 72: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest
|
||||
86, // 73: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest
|
||||
88, // 74: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
|
||||
90, // 75: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest
|
||||
6, // 76: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
||||
8, // 77: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
||||
10, // 78: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
||||
@@ -7052,9 +7118,9 @@ var file_daemon_proto_depIdxs = []int32{
|
||||
47, // 91: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
|
||||
49, // 92: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
|
||||
53, // 93: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
|
||||
93, // 94: daemon.DaemonService.StartCapture:output_type -> daemon.CapturePacket
|
||||
95, // 95: daemon.DaemonService.StartBundleCapture:output_type -> daemon.StartBundleCaptureResponse
|
||||
97, // 96: daemon.DaemonService.StopBundleCapture:output_type -> daemon.StopBundleCaptureResponse
|
||||
94, // 94: daemon.DaemonService.StartCapture:output_type -> daemon.CapturePacket
|
||||
96, // 95: daemon.DaemonService.StartBundleCapture:output_type -> daemon.StartBundleCaptureResponse
|
||||
98, // 96: daemon.DaemonService.StopBundleCapture:output_type -> daemon.StopBundleCaptureResponse
|
||||
55, // 97: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
|
||||
57, // 98: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
|
||||
59, // 99: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
|
||||
@@ -7065,14 +7131,14 @@ var file_daemon_proto_depIdxs = []int32{
|
||||
70, // 104: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
|
||||
72, // 105: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
|
||||
74, // 106: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
|
||||
76, // 107: daemon.DaemonService.TriggerUpdate:output_type -> daemon.TriggerUpdateResponse
|
||||
78, // 108: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
|
||||
80, // 109: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
|
||||
82, // 110: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
|
||||
84, // 111: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse
|
||||
86, // 112: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse
|
||||
88, // 113: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
|
||||
90, // 114: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent
|
||||
77, // 107: daemon.DaemonService.TriggerUpdate:output_type -> daemon.TriggerUpdateResponse
|
||||
79, // 108: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
|
||||
81, // 109: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
|
||||
83, // 110: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
|
||||
85, // 111: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse
|
||||
87, // 112: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse
|
||||
89, // 113: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
|
||||
91, // 114: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent
|
||||
76, // [76:115] is the sub-list for method output_type
|
||||
37, // [37:76] is the sub-list for method input_type
|
||||
37, // [37:37] is the sub-list for extension type_name
|
||||
@@ -7097,8 +7163,8 @@ func file_daemon_proto_init() {
|
||||
file_daemon_proto_msgTypes[54].OneofWrappers = []any{}
|
||||
file_daemon_proto_msgTypes[56].OneofWrappers = []any{}
|
||||
file_daemon_proto_msgTypes[67].OneofWrappers = []any{}
|
||||
file_daemon_proto_msgTypes[75].OneofWrappers = []any{}
|
||||
file_daemon_proto_msgTypes[86].OneofWrappers = []any{
|
||||
file_daemon_proto_msgTypes[76].OneofWrappers = []any{}
|
||||
file_daemon_proto_msgTypes[87].OneofWrappers = []any{
|
||||
(*ExposeServiceEvent_Ready)(nil),
|
||||
}
|
||||
type x struct{}
|
||||
@@ -7107,7 +7173,7 @@ func file_daemon_proto_init() {
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
|
||||
NumEnums: 4,
|
||||
NumMessages: 97,
|
||||
NumMessages: 98,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
|
||||
@@ -314,6 +314,13 @@ message GetConfigResponse {
|
||||
int32 sshJWTCacheTTL = 26;
|
||||
|
||||
bool disable_ipv6 = 27;
|
||||
|
||||
// mDMManagedFields lists the names of configuration keys whose value is
|
||||
// currently enforced by an MDM policy. Names match mdm.Key* constants
|
||||
// (e.g. "managementURL", "disableClientRoutes"). UI/CLI clients should
|
||||
// render the corresponding inputs as read-only and display a "managed
|
||||
// by MDM" indicator.
|
||||
repeated string mDMManagedFields = 28;
|
||||
}
|
||||
|
||||
// PeerState contains the latest state of a peer
|
||||
@@ -733,6 +740,15 @@ message GetFeaturesResponse{
|
||||
bool disable_networks = 3;
|
||||
}
|
||||
|
||||
// MDMManagedFieldsViolation is attached as a gRPC error detail on a
|
||||
// FailedPrecondition status returned from SetConfig (and similar mutating
|
||||
// RPCs) when the caller tries to modify one or more MDM-enforced fields.
|
||||
// The fields list contains the offending key names; the entire request is
|
||||
// rejected (no partial apply).
|
||||
message MDMManagedFieldsViolation {
|
||||
repeated string fields = 1;
|
||||
}
|
||||
|
||||
message TriggerUpdateRequest {}
|
||||
|
||||
message TriggerUpdateResponse {
|
||||
|
||||
419
client/server/mdm.go
Normal file
419
client/server/mdm.go
Normal file
@@ -0,0 +1,419 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// preSharedKeyRedactedSentinel is the value GetConfig returns in place
|
||||
// of an actual PSK, so a UI that round-trips the field back to the
|
||||
// daemon (via SetConfig / Login) can be distinguished from a deliberate
|
||||
// override. Any incoming PSK that equals this sentinel is treated as
|
||||
// a no-op echo, never as a conflict with the policy.
|
||||
const preSharedKeyRedactedSentinel = "**********"
|
||||
|
||||
// loadMDMPolicy is the indirection used by server handlers to read the
|
||||
// active MDM policy. Tests override this to inject a fake policy.
|
||||
var loadMDMPolicy = mdm.LoadPolicy
|
||||
|
||||
// conflictCheck is a value-aware comparison between a single field in
|
||||
// the incoming request and the corresponding MDM-enforced value. It
|
||||
// runs only when the field was actually set in the request (presence
|
||||
// already filtered upstream); ok=true reports the policy value, ok=false
|
||||
// means the policy is silent on the key — both are treated as conflicts
|
||||
// to be safe (an MDM key declared as managed must hold a value).
|
||||
type conflictCheck struct {
|
||||
key string
|
||||
check func(*mdm.Policy) (match bool)
|
||||
}
|
||||
|
||||
// onMDMPolicyChange is invoked by the MDM reload ticker every time the
|
||||
// OS-native managed-config store reports a diff vs the last observation.
|
||||
//
|
||||
// Restart sequence:
|
||||
// 1. Cancel the active engine context (terminates connectWithRetryRuns).
|
||||
// 2. Wait briefly for that goroutine to exit (giveUpChan is closed on exit).
|
||||
// 3. Re-resolve Config from disk + MDM policy (Config.apply re-runs
|
||||
// applyMDMPolicy with the freshly loaded Policy).
|
||||
// 4. Spawn a fresh connectWithRetryRuns with the new context and config.
|
||||
// 5. Broadcast a SystemEvent so any GUI / CLI subscriber (SubscribeEvents
|
||||
// RPC) can refresh its cached config view without polling.
|
||||
//
|
||||
// The callback runs in the ticker's own goroutine. Ticker has already
|
||||
// logged the per-key diff before invoking this hook.
|
||||
func (s *Server) onMDMPolicyChange(_, _ *mdm.Policy) error {
|
||||
log.Warn("MDM policy changed; restarting engine to apply new configuration")
|
||||
|
||||
// Hold s.mutex for the entire restart sequence (cancel + quiescence
|
||||
// wait + re-spawn). Any concurrent Up/Down/Status arriving while
|
||||
// MDM is restarting blocks on the Lock until we are done — they
|
||||
// then observe the post-restart state coherently. This is safe
|
||||
// because the connectWithRetryRuns goroutine no longer acquires
|
||||
// s.mutex in its defer (intent vs. goroutine-alive concerns are
|
||||
// fully separated; see the connectionGoroutineRunning helper).
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if !s.clientRunning {
|
||||
// The client is not running, so there's no engine to restart.
|
||||
return nil
|
||||
}
|
||||
if s.actCancel != nil {
|
||||
s.actCancel()
|
||||
}
|
||||
|
||||
// Wait for previous connectWithRetryRuns to exit so we don't end up
|
||||
// with two goroutines fighting over the same status recorder + engine.
|
||||
// The teardown engages a fan-out of engine goroutines (peer workers,
|
||||
// signal handler, route manager, ...). close(clientGiveUpChan)
|
||||
// happens in the function-scope defer of connectWithRetryRuns, on
|
||||
// every exit path (ctx cancel, backoff exhausted, panic) — see the
|
||||
// defer in server.go.
|
||||
if s.clientGiveUpChan != nil {
|
||||
select {
|
||||
case <-s.clientGiveUpChan:
|
||||
case <-time.After(10 * time.Second):
|
||||
return fmt.Errorf("failed to restart the engine due to timeout")
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.restartEngineForMDMLocked(); err != nil {
|
||||
log.Errorf("MDM restart failed: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// publishConfigChangedEvent has already fired inside
|
||||
// restartEngineForMDMLocked with source="mdm". Emit an MDM-specific
|
||||
// user-visible toast so the operator knows their IT policy was
|
||||
// applied (UserMessage != "" triggers the GUI notifier).
|
||||
s.statusRecorder.PublishEvent(
|
||||
proto.SystemEvent_INFO,
|
||||
proto.SystemEvent_SYSTEM,
|
||||
"MDM policy applied",
|
||||
"NetBird configuration was updated by your IT policy.",
|
||||
map[string]string{"source": "mdm", "type": "policy_applied"},
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// publishConfigChangedEvent broadcasts a SystemEvent informing any active
|
||||
// SubscribeEvents subscriber (typically the GUI tray) that the daemon's
|
||||
// effective Config has been replaced and any cached client-side view
|
||||
// should be refreshed. Callers pass a stable `source` label so the GUI
|
||||
// can distinguish a startup spawn from a user-triggered Up or an
|
||||
// MDM-driven restart. Reusing the SYSTEM category keeps the proto enum
|
||||
// stable; metadata.type="config_changed" routes to the GUI's refresh
|
||||
// handler. UserMessage is left empty so the system tray does not toast
|
||||
// for every internal restart; the MDM path emits a separate
|
||||
// "policy_applied" event (with UserMessage) for that purpose.
|
||||
func (s *Server) publishConfigChangedEvent(source string) {
|
||||
if s.statusRecorder == nil {
|
||||
return
|
||||
}
|
||||
s.statusRecorder.PublishEvent(
|
||||
proto.SystemEvent_INFO,
|
||||
proto.SystemEvent_SYSTEM,
|
||||
fmt.Sprintf("daemon config changed (source=%s)", source),
|
||||
"",
|
||||
map[string]string{
|
||||
"source": source,
|
||||
"type": "config_changed",
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// restartEngineForMDMLocked re-resolves the active profile config
|
||||
// (re-running applyMDMPolicy via Config.apply) and re-spawns
|
||||
// connectWithRetryRuns. Mirrors the tail of Server.Start so a runtime
|
||||
// MDM change behaves identically to a fresh boot under the new policy.
|
||||
//
|
||||
// MUST be called with s.mutex held — onMDMPolicyChange holds the lock
|
||||
// for the entire restart sequence (cancel + quiescence wait + re-spawn)
|
||||
// so concurrent Up/Down/Status RPCs observe a coherent post-restart
|
||||
// state.
|
||||
func (s *Server) restartEngineForMDMLocked() error {
|
||||
activeProf, err := s.profileManager.GetActiveProfileState()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile state: %w", err)
|
||||
}
|
||||
config, _, err := s.getConfig(activeProf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile config: %w", err)
|
||||
}
|
||||
|
||||
s.config = config
|
||||
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
|
||||
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
|
||||
s.statusRecorder.UpdateLazyConnection(config.LazyConnectionEnabled)
|
||||
|
||||
ctx, cancel := context.WithCancel(s.rootCtx)
|
||||
s.actCancel = cancel
|
||||
s.clientRunning = true
|
||||
s.clientRunningChan = make(chan struct{})
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
log.Info("MDM restart: spawning connectWithRetryRuns with re-resolved config")
|
||||
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
||||
s.publishConfigChangedEvent("mdm")
|
||||
return nil
|
||||
}
|
||||
|
||||
// conflictBool builds a conflictCheck for a boolean MDM key. If p is nil
|
||||
// the field is treated as matching (no override requested); otherwise the
|
||||
// check returns true only when the policy contains the key and its
|
||||
// boolean value equals *p.
|
||||
func conflictBool(key string, p *bool) conflictCheck {
|
||||
return conflictCheck{
|
||||
key: key,
|
||||
check: func(pol *mdm.Policy) bool {
|
||||
if p == nil {
|
||||
return true // absent → match by definition
|
||||
}
|
||||
want, ok := pol.GetBool(key)
|
||||
return ok && want == *p
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// conflictString builds a conflictCheck for a string MDM key. An empty
|
||||
// `got` is treated as "field not set" (no override requested); otherwise
|
||||
// the check returns true only when the policy contains the key and its
|
||||
// value equals got.
|
||||
func conflictString(key, got string) conflictCheck {
|
||||
return conflictCheck{
|
||||
key: key,
|
||||
check: func(pol *mdm.Policy) bool {
|
||||
if got == "" {
|
||||
return true
|
||||
}
|
||||
want, ok := pol.GetString(key)
|
||||
return ok && want == got
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// conflictInt64 builds a conflictCheck for an integer MDM key. If p is
|
||||
// nil the field is treated as matching; otherwise the check returns
|
||||
// true only when the policy contains the key and its int value equals *p.
|
||||
func conflictInt64(key string, p *int64) conflictCheck {
|
||||
return conflictCheck{
|
||||
key: key,
|
||||
check: func(pol *mdm.Policy) bool {
|
||||
if p == nil {
|
||||
return true
|
||||
}
|
||||
want, ok := pol.GetInt(key)
|
||||
return ok && want == *p
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// resolveConflicts walks the per-field checks against the active MDM
|
||||
// policy and returns the names of keys whose requested value diverges
|
||||
// from the policy-enforced value. Keys not present in the policy are
|
||||
// skipped silently (the gate fires only for keys the admin has
|
||||
// actually pushed). Returns nil for an empty policy.
|
||||
func resolveConflicts(policy *mdm.Policy, checks []conflictCheck) []string {
|
||||
if policy.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
var conflicts []string
|
||||
for _, c := range checks {
|
||||
if !policy.HasKey(c.key) {
|
||||
continue
|
||||
}
|
||||
if !c.check(policy) {
|
||||
conflicts = append(conflicts, c.key)
|
||||
}
|
||||
}
|
||||
return conflicts
|
||||
}
|
||||
|
||||
// mdmManagedFieldConflicts returns the names of MDM-managed keys whose
|
||||
// requested value in the SetConfigRequest differs from the MDM-enforced
|
||||
// value. A field set to the same value the policy already enforces is
|
||||
// treated as a no-op echo (the GUI tray sends a full Config snapshot on
|
||||
// every toggle, so most fields in a typical request match the policy
|
||||
// exactly and must NOT be flagged as conflicts). The redacted PSK
|
||||
// sentinel ("**********") returned by GetConfig is recognised and
|
||||
// treated as no-op so the UI can safely round-trip it.
|
||||
func mdmManagedFieldConflicts(msg *proto.SetConfigRequest, policy *mdm.Policy) []string {
|
||||
if msg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// PSK round-trip echo: collapse the sentinel to empty so the
|
||||
// shared check treats it as "field not set".
|
||||
pskGot := ""
|
||||
if msg.OptionalPreSharedKey != nil && *msg.OptionalPreSharedKey != preSharedKeyRedactedSentinel {
|
||||
pskGot = *msg.OptionalPreSharedKey
|
||||
}
|
||||
|
||||
return resolveConflicts(policy, []conflictCheck{
|
||||
conflictString(mdm.KeyManagementURL, msg.ManagementUrl),
|
||||
conflictString(mdm.KeyPreSharedKey, pskGot),
|
||||
conflictBool(mdm.KeyRosenpassEnabled, msg.RosenpassEnabled),
|
||||
conflictBool(mdm.KeyRosenpassPermissive, msg.RosenpassPermissive),
|
||||
conflictBool(mdm.KeyDisableAutoConnect, msg.DisableAutoConnect),
|
||||
conflictBool(mdm.KeyAllowServerSSH, msg.ServerSSHAllowed),
|
||||
conflictBool(mdm.KeyDisableClientRoutes, msg.DisableClientRoutes),
|
||||
conflictBool(mdm.KeyDisableServerRoutes, msg.DisableServerRoutes),
|
||||
conflictBool(mdm.KeyBlockInbound, msg.BlockInbound),
|
||||
conflictInt64(mdm.KeyWireguardPort, msg.WireguardPort),
|
||||
})
|
||||
}
|
||||
|
||||
// setConfigRequestHasConfigOverrides reports whether the SetConfigRequest
|
||||
// carries ANY field that would actually mutate the persisted config.
|
||||
// The CLI builds a SetConfigRequest unconditionally on every
|
||||
// `netbird up` (see setupSetConfigReq in cmd/up.go) — a plain
|
||||
// `netbird up` produces a request with every field at its zero value;
|
||||
// the gate must skip such no-op invocations or it would always fire
|
||||
// even when the user did not pass any --flag. Returns false on a nil
|
||||
// msg; true when any management/admin URL, PSK, DNS/NAT list+clean
|
||||
// flag, interface/port/MTU, or any optional bool/duration field is set.
|
||||
func setConfigRequestHasConfigOverrides(msg *proto.SetConfigRequest) bool {
|
||||
if msg == nil {
|
||||
return false
|
||||
}
|
||||
return msg.ManagementUrl != "" ||
|
||||
msg.AdminURL != "" ||
|
||||
msg.OptionalPreSharedKey != nil ||
|
||||
len(msg.CustomDNSAddress) > 0 ||
|
||||
len(msg.NatExternalIPs) > 0 || msg.CleanNATExternalIPs ||
|
||||
len(msg.ExtraIFaceBlacklist) > 0 ||
|
||||
len(msg.DnsLabels) > 0 || msg.CleanDNSLabels ||
|
||||
msg.DnsRouteInterval != nil ||
|
||||
msg.RosenpassEnabled != nil ||
|
||||
msg.RosenpassPermissive != nil ||
|
||||
msg.InterfaceName != nil ||
|
||||
msg.WireguardPort != nil ||
|
||||
msg.Mtu != nil ||
|
||||
msg.DisableAutoConnect != nil ||
|
||||
msg.ServerSSHAllowed != nil ||
|
||||
msg.NetworkMonitor != nil ||
|
||||
msg.DisableClientRoutes != nil ||
|
||||
msg.DisableServerRoutes != nil ||
|
||||
msg.DisableDns != nil ||
|
||||
msg.DisableFirewall != nil ||
|
||||
msg.BlockLanAccess != nil ||
|
||||
msg.DisableNotifications != nil ||
|
||||
msg.LazyConnectionEnabled != nil ||
|
||||
msg.BlockInbound != nil ||
|
||||
msg.DisableIpv6 != nil ||
|
||||
msg.EnableSSHRoot != nil ||
|
||||
msg.EnableSSHSFTP != nil ||
|
||||
msg.EnableSSHLocalPortForwarding != nil ||
|
||||
msg.EnableSSHRemotePortForwarding != nil ||
|
||||
msg.DisableSSHAuth != nil ||
|
||||
msg.SshJWTCacheTTL != nil
|
||||
}
|
||||
|
||||
// loginRequestHasConfigOverrides reports whether the LoginRequest
|
||||
// carries ANY field that would mutate persisted daemon configuration
|
||||
// (as opposed to pure-auth fields like setupKey, hostname, hint,
|
||||
// profileName, username). Used by the Login handler to decide whether
|
||||
// the `--disable-update-settings` / MDM gates must run: a re-auth that
|
||||
// changes nothing about the configuration is always allowed.
|
||||
func loginRequestHasConfigOverrides(msg *proto.LoginRequest) bool {
|
||||
if msg == nil {
|
||||
return false
|
||||
}
|
||||
return msg.ManagementUrl != "" ||
|
||||
msg.AdminURL != "" ||
|
||||
msg.PreSharedKey != "" || //nolint:staticcheck // SA1019: legacy proto field still accepted by Login
|
||||
msg.OptionalPreSharedKey != nil ||
|
||||
len(msg.CustomDNSAddress) > 0 ||
|
||||
len(msg.NatExternalIPs) > 0 || msg.CleanNATExternalIPs ||
|
||||
msg.RosenpassEnabled != nil ||
|
||||
msg.InterfaceName != nil ||
|
||||
msg.WireguardPort != nil ||
|
||||
msg.DisableAutoConnect != nil ||
|
||||
msg.ServerSSHAllowed != nil ||
|
||||
msg.RosenpassPermissive != nil ||
|
||||
len(msg.ExtraIFaceBlacklist) > 0 ||
|
||||
msg.NetworkMonitor != nil ||
|
||||
msg.DnsRouteInterval != nil ||
|
||||
msg.DisableClientRoutes != nil ||
|
||||
msg.DisableServerRoutes != nil ||
|
||||
msg.DisableDns != nil ||
|
||||
msg.DisableFirewall != nil ||
|
||||
msg.BlockLanAccess != nil ||
|
||||
msg.DisableNotifications != nil ||
|
||||
len(msg.DnsLabels) > 0 || msg.CleanDNSLabels ||
|
||||
msg.LazyConnectionEnabled != nil ||
|
||||
msg.BlockInbound != nil
|
||||
}
|
||||
|
||||
// loginRequestMDMConflicts mirrors mdmManagedFieldConflicts but for the
|
||||
// LoginRequest surface. Same value-aware semantics: a field set to the
|
||||
// MDM-enforced value is a no-op echo, not a conflict; only a divergent
|
||||
// value is flagged. PSK has two proto fields — PreSharedKey (deprecated)
|
||||
// and OptionalPreSharedKey (current); either route trips the gate if it
|
||||
// diverges from the MDM-enforced PSK. OptionalPreSharedKey wins when
|
||||
// both are set; the redaction sentinel ("**********") is accepted as
|
||||
// a no-op echo.
|
||||
func loginRequestMDMConflicts(msg *proto.LoginRequest, policy *mdm.Policy) []string {
|
||||
if msg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Collapse the two PSK fields + the redaction sentinel down to a
|
||||
// single "got" string the shared check can compare against the
|
||||
// policy: OptionalPreSharedKey wins if set; PreSharedKey (deprecated)
|
||||
// is the fallback; sentinel echo is treated as "field not set".
|
||||
pskGot := ""
|
||||
if msg.OptionalPreSharedKey != nil {
|
||||
pskGot = *msg.OptionalPreSharedKey
|
||||
} else if msg.PreSharedKey != "" { //nolint:staticcheck // SA1019: legacy proto field still accepted by Login
|
||||
pskGot = msg.PreSharedKey //nolint:staticcheck // SA1019
|
||||
}
|
||||
if pskGot == preSharedKeyRedactedSentinel {
|
||||
pskGot = ""
|
||||
}
|
||||
|
||||
return resolveConflicts(policy, []conflictCheck{
|
||||
conflictString(mdm.KeyManagementURL, msg.ManagementUrl),
|
||||
conflictString(mdm.KeyPreSharedKey, pskGot),
|
||||
conflictBool(mdm.KeyRosenpassEnabled, msg.RosenpassEnabled),
|
||||
conflictBool(mdm.KeyRosenpassPermissive, msg.RosenpassPermissive),
|
||||
conflictBool(mdm.KeyDisableAutoConnect, msg.DisableAutoConnect),
|
||||
conflictBool(mdm.KeyAllowServerSSH, msg.ServerSSHAllowed),
|
||||
conflictBool(mdm.KeyDisableClientRoutes, msg.DisableClientRoutes),
|
||||
conflictBool(mdm.KeyDisableServerRoutes, msg.DisableServerRoutes),
|
||||
conflictBool(mdm.KeyBlockInbound, msg.BlockInbound),
|
||||
conflictInt64(mdm.KeyWireguardPort, msg.WireguardPort),
|
||||
})
|
||||
}
|
||||
|
||||
// rejectMDMManagedFieldConflicts returns a FailedPrecondition gRPC error
|
||||
// with an MDMManagedFieldsViolation detail when any of the requested
|
||||
// fields tries to change an MDM-enforced value to something else, and
|
||||
// nil otherwise. The whole request is rejected on any conflict; non-
|
||||
// conflicting fields in the same request are not applied either (no
|
||||
// partial apply).
|
||||
func rejectMDMManagedFieldConflicts(conflicts []string) error {
|
||||
if len(conflicts) == 0 {
|
||||
return nil
|
||||
}
|
||||
log.Warnf("MDM rejected request: tried to modify %d managed key(s): %v",
|
||||
len(conflicts), conflicts)
|
||||
st := gstatus.New(
|
||||
codes.FailedPrecondition,
|
||||
fmt.Sprintf("fields managed by MDM cannot be modified: %v", conflicts),
|
||||
)
|
||||
detailed, err := st.WithDetails(&proto.MDMManagedFieldsViolation{Fields: conflicts})
|
||||
if err != nil {
|
||||
// Detail attachment is best-effort; fall back to the plain status
|
||||
// so the caller still gets a usable FailedPrecondition.
|
||||
return st.Err()
|
||||
}
|
||||
return detailed.Err()
|
||||
}
|
||||
@@ -30,7 +30,7 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.networksDisabled {
|
||||
if s.checkNetworksDisabled() {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
@@ -143,7 +143,7 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.networksDisabled {
|
||||
if s.checkNetworksDisabled() {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
@@ -195,7 +195,7 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.networksDisabled {
|
||||
if s.checkNetworksDisabled() {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/expose"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
sleephandler "github.com/netbirdio/netbird/client/internal/sleep/handler"
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
@@ -71,7 +72,13 @@ type Server struct {
|
||||
mutex sync.Mutex
|
||||
config *profilemanager.Config
|
||||
proto.UnimplementedDaemonServiceServer
|
||||
clientRunning bool // protected by mutex
|
||||
// clientRunning tracks "the daemon wants to be connected" — set true by
|
||||
// Start / Up, cleared by Down / Logout. Persists across retry
|
||||
// loops, signal disconnects, and ErrResetConnection cycles. NOT
|
||||
// changed by connectWithRetryRuns goroutine exit — for that
|
||||
// (goroutine-still-alive) check, see connectionGoroutineRunning() which
|
||||
// derives from clientGiveUpChan close state. Protected by s.mutex.
|
||||
clientRunning bool
|
||||
clientRunningChan chan struct{}
|
||||
clientGiveUpChan chan struct{} // closed when connectWithRetryRuns goroutine exits
|
||||
|
||||
@@ -98,6 +105,11 @@ type Server struct {
|
||||
|
||||
sleepHandler *sleephandler.SleepHandler
|
||||
|
||||
// mdmTicker periodically re-reads the OS-native MDM policy and triggers
|
||||
// an engine restart when the policy changes. Launched once by Start;
|
||||
// stopped by the rootCtx cancellation.
|
||||
mdmTicker *mdm.Ticker
|
||||
|
||||
updateManager *updater.Manager
|
||||
|
||||
jwtCache *jwtCache
|
||||
@@ -155,6 +167,17 @@ func (s *Server) Start() error {
|
||||
s.updateManager.CheckUpdateSuccess(s.rootCtx)
|
||||
}
|
||||
|
||||
// MDM policy reload ticker: every minute the desktop daemon re-reads
|
||||
// the OS-native managed-config store and, on diff vs the previous
|
||||
// observation, cancels the active engine context so connectWithRetry-
|
||||
// Runs re-resolves Config (re-running profilemanager.Config.apply which
|
||||
// applies the freshly-read MDM policy as the last layer) and brings
|
||||
// the engine back with the new values.
|
||||
if s.mdmTicker == nil {
|
||||
s.mdmTicker = mdm.NewTicker(mdm.DefaultReloadInterval)
|
||||
go s.mdmTicker.Run(s.rootCtx, s.onMDMPolicyChange)
|
||||
}
|
||||
|
||||
// if current state contains any error, return it
|
||||
// in all other cases we can continue execution only if status is idle and up command was
|
||||
// not in the progress or already successfully established connection.
|
||||
@@ -213,17 +236,27 @@ func (s *Server) Start() error {
|
||||
s.clientRunningChan = make(chan struct{})
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
||||
s.publishConfigChangedEvent("startup")
|
||||
return nil
|
||||
}
|
||||
|
||||
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
|
||||
// mechanism to keep the client connected even when the connection is lost.
|
||||
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
|
||||
//
|
||||
// The goroutine's exit is signalled to the daemon via close(giveUpChan)
|
||||
// — placed in the function-scope defer so every return path (panic,
|
||||
// DisableAutoConnect early-exit, backoff exhausted, ctx cancel) closes
|
||||
// it. Callers that need to observe "is the goroutine still alive?" use
|
||||
// Server.connectionGoroutineRunning() which non-blockingly checks the close state
|
||||
// of clientGiveUpChan. The defer does NOT touch s.mutex; the daemon's
|
||||
// "intent" (clientRunning) is maintained by the RPC handlers, not by this
|
||||
// goroutine.
|
||||
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
|
||||
defer func() {
|
||||
s.mutex.Lock()
|
||||
s.clientRunning = false
|
||||
s.mutex.Unlock()
|
||||
if giveUpChan != nil {
|
||||
close(giveUpChan)
|
||||
}
|
||||
}()
|
||||
|
||||
if s.config.DisableAutoConnect {
|
||||
@@ -269,9 +302,26 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
|
||||
if err := backoff.Retry(runOperation, backOff); err != nil {
|
||||
log.Errorf("operation failed: %v", err)
|
||||
}
|
||||
// giveUpChan is closed by the function-scope defer.
|
||||
}
|
||||
|
||||
if giveUpChan != nil {
|
||||
close(giveUpChan)
|
||||
// connectionGoroutineRunning reports whether the connectWithRetryRuns goroutine is
|
||||
// still running. Returns false when no goroutine has ever been started
|
||||
// AND when the most recent one has already closed clientGiveUpChan on
|
||||
// exit (whether due to ctx cancel, DisableAutoConnect single-shot
|
||||
// completion, or backoff retry exhaustion).
|
||||
//
|
||||
// MUST be called with s.mutex held — accesses s.clientGiveUpChan which
|
||||
// is written by Start/Up under the same lock.
|
||||
func (s *Server) connectionGoroutineRunning() bool {
|
||||
if s.clientGiveUpChan == nil {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case <-s.clientGiveUpChan:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -304,54 +354,85 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.checkUpdateSettingsDisabled() {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errUpdateSettingsDisabled)
|
||||
// Skip the update-settings gate when the request carries no actual
|
||||
// overrides: the CLI builds a SetConfigRequest unconditionally on
|
||||
// every `netbird up` (setupSetConfigReq in cmd/up.go), so a plain
|
||||
// `netbird up` would otherwise always trip the gate and surface a
|
||||
// misleading "setConfig method is not available" warning, even when
|
||||
// the user did not pass any config flag.
|
||||
if setConfigRequestHasConfigOverrides(msg) {
|
||||
if s.checkUpdateSettingsDisabled() {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errUpdateSettingsDisabled)
|
||||
}
|
||||
}
|
||||
|
||||
// MDM gate: refuse the whole request if any of its fields is enforced
|
||||
// by the active MDM policy. The error carries an MDMManagedFields-
|
||||
// Violation detail listing the offending key names. Non-conflicting
|
||||
// fields in the same request are not applied either.
|
||||
policy := loadMDMPolicy()
|
||||
if err := rejectMDMManagedFieldConflicts(mdmManagedFieldConflicts(msg, policy)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config, err := setConfigInputFromRequest(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := profilemanager.UpdateConfig(config); err != nil {
|
||||
log.Errorf("failed to update profile config: %v", err)
|
||||
return nil, fmt.Errorf("failed to update profile config: %w", err)
|
||||
}
|
||||
|
||||
return &proto.SetConfigResponse{}, nil
|
||||
}
|
||||
|
||||
// setConfigInputFromRequest translates a SetConfigRequest into the
|
||||
// profilemanager.ConfigInput that profilemanager.UpdateConfig consumes.
|
||||
// Pure mapping with no business logic beyond presence-aware copying of
|
||||
// optional fields and the "empty / clean" semantics for the two slice
|
||||
// fields (DNS labels, NAT external IPs). Extracted from SetConfig to
|
||||
// keep the handler's cognitive complexity below the SonarCube
|
||||
// threshold; the body is intentionally linear because each proto
|
||||
// field is its own optional case. Returns the resolved ConfigInput
|
||||
// and a non-nil error only when the active profile file path cannot
|
||||
// be determined.
|
||||
func setConfigInputFromRequest(msg *proto.SetConfigRequest) (profilemanager.ConfigInput, error) {
|
||||
var config profilemanager.ConfigInput
|
||||
|
||||
profState := profilemanager.ActiveProfileState{
|
||||
Name: msg.ProfileName,
|
||||
Username: msg.Username,
|
||||
}
|
||||
|
||||
profPath, err := profState.FilePath()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get active profile file path: %v", err)
|
||||
return nil, fmt.Errorf("failed to get active profile file path: %w", err)
|
||||
return config, fmt.Errorf("failed to get active profile file path: %w", err)
|
||||
}
|
||||
|
||||
var config profilemanager.ConfigInput
|
||||
|
||||
config.ConfigPath = profPath
|
||||
|
||||
if msg.ManagementUrl != "" {
|
||||
config.ManagementURL = msg.ManagementUrl
|
||||
}
|
||||
|
||||
if msg.AdminURL != "" {
|
||||
config.AdminURL = msg.AdminURL
|
||||
}
|
||||
|
||||
if msg.InterfaceName != nil {
|
||||
config.InterfaceName = msg.InterfaceName
|
||||
}
|
||||
|
||||
if msg.WireguardPort != nil {
|
||||
wgPort := int(*msg.WireguardPort)
|
||||
config.WireguardPort = &wgPort
|
||||
}
|
||||
|
||||
if msg.OptionalPreSharedKey != nil {
|
||||
if *msg.OptionalPreSharedKey != "" {
|
||||
config.PreSharedKey = msg.OptionalPreSharedKey
|
||||
}
|
||||
if msg.OptionalPreSharedKey != nil && *msg.OptionalPreSharedKey != "" {
|
||||
config.PreSharedKey = msg.OptionalPreSharedKey
|
||||
}
|
||||
|
||||
if msg.CleanDNSLabels {
|
||||
config.DNSLabels = domain.List{}
|
||||
|
||||
} else if msg.DnsLabels != nil {
|
||||
dnsLabels := domain.FromPunycodeList(msg.DnsLabels)
|
||||
config.DNSLabels = dnsLabels
|
||||
config.DNSLabels = domain.FromPunycodeList(msg.DnsLabels)
|
||||
}
|
||||
|
||||
if msg.CleanNATExternalIPs {
|
||||
@@ -364,7 +445,6 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
if string(msg.CustomDNSAddress) == "empty" {
|
||||
config.CustomDNSAddress = []byte{}
|
||||
}
|
||||
|
||||
config.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
|
||||
|
||||
if msg.DnsRouteInterval != nil {
|
||||
@@ -397,22 +477,31 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
ttl := int(*msg.SshJWTCacheTTL)
|
||||
config.SSHJWTCacheTTL = &ttl
|
||||
}
|
||||
|
||||
if msg.Mtu != nil {
|
||||
mtu := uint16(*msg.Mtu)
|
||||
config.MTU = &mtu
|
||||
}
|
||||
|
||||
if _, err := profilemanager.UpdateConfig(config); err != nil {
|
||||
log.Errorf("failed to update profile config: %v", err)
|
||||
return nil, fmt.Errorf("failed to update profile config: %w", err)
|
||||
}
|
||||
|
||||
return &proto.SetConfigResponse{}, nil
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// Login uses setup key to prepare configuration for the daemon.
|
||||
func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*proto.LoginResponse, error) {
|
||||
// Config-override gates. LoginRequest carries the same surface as
|
||||
// SetConfigRequest (managementUrl, PSK, ssh/rosenpass/port toggles,
|
||||
// ...), so the same protections must apply. Without these the CLI
|
||||
// command `netbird up --management-url=X` (which falls through to
|
||||
// Login when SetConfig is rejected — see cmd/up.go) would silently
|
||||
// bypass `--disable-update-settings` and any MDM policy.
|
||||
if loginRequestHasConfigOverrides(msg) {
|
||||
if s.checkUpdateSettingsDisabled() {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errUpdateSettingsDisabled)
|
||||
}
|
||||
policy := loadMDMPolicy()
|
||||
if err := rejectMDMManagedFieldConflicts(loginRequestMDMConflicts(msg, policy)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
if s.actCancel != nil {
|
||||
s.actCancel()
|
||||
@@ -652,7 +741,13 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
||||
// Up starts engine work in the daemon.
|
||||
func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpResponse, error) {
|
||||
s.mutex.Lock()
|
||||
if s.clientRunning {
|
||||
// clientRunning is the daemon-intent flag (set by previous Up/Start, cleared
|
||||
// by Down). connectionGoroutineRunning() reports whether the previous retry-loop
|
||||
// goroutine is still trying. When intent is up AND goroutine is alive,
|
||||
// the existing engine is on the job — just wait for it. When intent
|
||||
// is up but the goroutine has given up (backoff exhausted) OR when
|
||||
// intent is down, fall through to spawn a fresh retry loop.
|
||||
if s.clientRunning && s.connectionGoroutineRunning() {
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
@@ -743,6 +838,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
|
||||
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
||||
s.publishConfigChangedEvent("up_rpc")
|
||||
|
||||
s.mutex.Unlock()
|
||||
return s.waitForUp(callerCtx)
|
||||
@@ -871,6 +967,12 @@ func (s *Server) cleanupConnection() error {
|
||||
return ErrServiceNotUp
|
||||
}
|
||||
|
||||
// Daemon intent flips to "down" — all callers (Down RPC,
|
||||
// Logout RPC handlers) tear down the connection because the user
|
||||
// explicitly asked for it. MDM restart does NOT go through this
|
||||
// path, so its clientRunning stays true.
|
||||
s.clientRunning = false
|
||||
|
||||
// Capture the engine reference before cancelling the context.
|
||||
// After actCancel(), the connectWithRetryRuns goroutine wakes up
|
||||
// and sets connectClient.engine = nil, causing connectClient.Stop()
|
||||
@@ -1074,10 +1176,14 @@ func (s *Server) Status(
|
||||
msg *proto.StatusRequest,
|
||||
) (*proto.StatusResponse, error) {
|
||||
s.mutex.Lock()
|
||||
clientRunning := s.clientRunning
|
||||
// Only wait if the retry-loop goroutine is alive and making
|
||||
// progress. clientRunning=true with connectionGoroutineRunning=false means the
|
||||
// backoff has given up — there is nothing to wait for; let the
|
||||
// caller observe the failed status directly.
|
||||
alive := s.connectionGoroutineRunning()
|
||||
s.mutex.Unlock()
|
||||
|
||||
if msg.WaitForReady != nil && *msg.WaitForReady && clientRunning {
|
||||
if msg.WaitForReady != nil && *msg.WaitForReady && alive {
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
@@ -1548,6 +1654,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
|
||||
DisableSSHAuth: disableSSHAuth,
|
||||
SshJWTCacheTTL: sshJWTCacheTTL,
|
||||
MDMManagedFields: cfg.Policy().ManagedKeys(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1646,7 +1753,7 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest)
|
||||
features := &proto.GetFeaturesResponse{
|
||||
DisableProfiles: s.checkProfilesDisabled(),
|
||||
DisableUpdateSettings: s.checkUpdateSettingsDisabled(),
|
||||
DisableNetworks: s.networksDisabled,
|
||||
DisableNetworks: s.checkNetworksDisabled(),
|
||||
}
|
||||
|
||||
return features, nil
|
||||
@@ -1668,22 +1775,46 @@ func (s *Server) connect(ctx context.Context, config *profilemanager.Config, sta
|
||||
return nil
|
||||
}
|
||||
|
||||
// MDM authority: when the platform-native MDM source sets a kill switch
|
||||
// key (regardless of true/false value), that value wins. The CLI flag
|
||||
// supplied at service install time is the fallback used only when the
|
||||
// MDM source is silent on the key. This honors the "MDM decides
|
||||
// everything" semantic agreed for NET-1214 — an admin pushing
|
||||
// disableX=false via MDM explicitly re-enables the feature even on a
|
||||
// box installed with --disable-X.
|
||||
func (s *Server) checkProfilesDisabled() bool {
|
||||
// Check if the environment variable is set to disable profiles
|
||||
if s.profilesDisabled {
|
||||
return true
|
||||
if s.config != nil {
|
||||
if v, ok := s.config.Policy().GetBool(mdm.KeyDisableProfiles); ok {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return s.profilesDisabled
|
||||
}
|
||||
|
||||
return false
|
||||
// checkNetworksDisabled reports whether the networks/exit-node feature
|
||||
// is disabled on this daemon instance. Resolved MDM-first: when the
|
||||
// active policy declares mdm.KeyDisableNetworks the policy value wins
|
||||
// (regardless of true/false), so an admin can re-enable the feature
|
||||
// via MDM even on a host that was installed with --disable-networks.
|
||||
// Falls back to the s.networksDisabled CLI flag when the policy is
|
||||
// silent on the key. Mirrors checkProfilesDisabled and
|
||||
// checkUpdateSettingsDisabled.
|
||||
func (s *Server) checkNetworksDisabled() bool {
|
||||
if s.config != nil {
|
||||
if v, ok := s.config.Policy().GetBool(mdm.KeyDisableNetworks); ok {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return s.networksDisabled
|
||||
}
|
||||
|
||||
func (s *Server) checkUpdateSettingsDisabled() bool {
|
||||
// Check if the environment variable is set to disable profiles
|
||||
if s.updateSettingsDisabled {
|
||||
return true
|
||||
if s.config != nil {
|
||||
if v, ok := s.config.Policy().GetBool(mdm.KeyDisableUpdateSettings); ok {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return s.updateSettingsDisabled
|
||||
}
|
||||
|
||||
func (s *Server) startUpdateManagerForGUI() {
|
||||
|
||||
@@ -101,6 +101,7 @@ func TestCleanupConnection_ClearsConnectClient(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup")
|
||||
assert.False(t, s.clientRunning, "clientRunning should be cleared after cleanup (intent = down)")
|
||||
}
|
||||
|
||||
// TestCleanState_NilConnectClient validates that CleanState doesn't panic
|
||||
@@ -144,17 +145,20 @@ func TestDownThenUp_StaleRunningChan(t *testing.T) {
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
s.actCancel = cancel
|
||||
|
||||
// Simulate Down(): cleanupConnection sets connectClient = nil
|
||||
// Simulate Down(): cleanupConnection sets connectClient = nil and
|
||||
// flips clientRunning to false (intent = down). The connectionGoroutineRunning state
|
||||
// remains independent of intent — derived from clientGiveUpChan.
|
||||
s.mutex.Lock()
|
||||
err := s.cleanupConnection()
|
||||
s.mutex.Unlock()
|
||||
require.NoError(t, err)
|
||||
|
||||
// After cleanup: connectClient is nil, clientRunning still true
|
||||
// (goroutine hasn't exited yet)
|
||||
// After cleanup: connectClient is nil, clientRunning is false (intent
|
||||
// cleared by cleanupConnection), connectionGoroutineRunning may still be true
|
||||
// (goroutine teardown is independent of the intent flag).
|
||||
s.mutex.Lock()
|
||||
assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup")
|
||||
assert.True(t, s.clientRunning, "clientRunning still true until goroutine exits")
|
||||
assert.False(t, s.clientRunning, "clientRunning should be cleared by cleanupConnection (intent = down)")
|
||||
s.mutex.Unlock()
|
||||
|
||||
// waitForUp() returns immediately due to stale closed clientRunningChan
|
||||
|
||||
235
client/server/server_privileged_test.go
Normal file
235
client/server/server_privileged_test.go
Normal file
@@ -0,0 +1,235 @@
|
||||
//go:build privileged
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os/user"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
)
|
||||
|
||||
var (
|
||||
kaep = keepalive.EnforcementPolicy{
|
||||
MinTime: 15 * time.Second,
|
||||
PermitWithoutStream: true,
|
||||
}
|
||||
|
||||
kasp = keepalive.ServerParameters{
|
||||
MaxConnectionIdle: 15 * time.Second,
|
||||
MaxConnectionAgeGrace: 5 * time.Second,
|
||||
Time: 5 * time.Second,
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
)
|
||||
|
||||
// TestConnectWithRetryRuns checks that the connectWithRetry function runs and runs the retries according to the times specified via environment variables
|
||||
// we will use a management server started via to simulate the server and capture the number of retries
|
||||
func TestConnectWithRetryRuns(t *testing.T) {
|
||||
// start the signal server
|
||||
_, signalAddr, err := startSignal(t)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start signal server: %v", err)
|
||||
}
|
||||
|
||||
counter := 0
|
||||
// start the management server
|
||||
_, mgmtAddr, err := startManagement(t, signalAddr, &counter)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start management server: %v", err)
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
|
||||
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second))
|
||||
defer cancel()
|
||||
// create new server
|
||||
ic := profilemanager.ConfigInput{
|
||||
ManagementURL: "http://" + mgmtAddr,
|
||||
ConfigPath: t.TempDir() + "/test-profile.json",
|
||||
}
|
||||
|
||||
config, err := profilemanager.UpdateOrCreateConfig(ic)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create config: %v", err)
|
||||
}
|
||||
|
||||
currUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: "test-profile",
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to set active profile state: %v", err)
|
||||
}
|
||||
|
||||
s := New(ctx, "debug", "", false, false, false, false)
|
||||
|
||||
s.config = config
|
||||
|
||||
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
|
||||
t.Setenv(retryInitialIntervalVar, "1s")
|
||||
t.Setenv(maxRetryIntervalVar, "2s")
|
||||
t.Setenv(maxRetryTimeVar, "5s")
|
||||
t.Setenv(retryMultiplierVar, "1")
|
||||
|
||||
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
|
||||
if counter < 3 {
|
||||
t.Fatalf("expected counter > 2, got %d", counter)
|
||||
}
|
||||
}
|
||||
|
||||
type mockServer struct {
|
||||
mgmtProto.ManagementServiceServer
|
||||
counter *int
|
||||
}
|
||||
|
||||
func (m *mockServer) Login(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
|
||||
*m.counter++
|
||||
return m.ManagementServiceServer.Login(ctx, req)
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
|
||||
config := &config.Config{
|
||||
Stuns: []*config.Host{},
|
||||
TURNConfig: &config.TURNConfig{},
|
||||
Signal: &config.Host{
|
||||
Proto: "http",
|
||||
URI: signalAddr,
|
||||
},
|
||||
Datadir: dataDir,
|
||||
HttpConfig: nil,
|
||||
}
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
peersManager := peers.NewManager(store, permissionsManagerMock)
|
||||
settingsManagerMock := settings.NewMockManager(ctrl)
|
||||
|
||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||
|
||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mock := &mockServer{
|
||||
ManagementServiceServer: mgmtServer,
|
||||
counter: counter,
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(s, mock)
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
log.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
|
||||
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
proto.RegisterSignalExchangeServer(s, srv)
|
||||
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
@@ -2,124 +2,22 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/url"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
daemonProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
)
|
||||
|
||||
var (
|
||||
kaep = keepalive.EnforcementPolicy{
|
||||
MinTime: 15 * time.Second,
|
||||
PermitWithoutStream: true,
|
||||
}
|
||||
|
||||
kasp = keepalive.ServerParameters{
|
||||
MaxConnectionIdle: 15 * time.Second,
|
||||
MaxConnectionAgeGrace: 5 * time.Second,
|
||||
Time: 5 * time.Second,
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
)
|
||||
|
||||
// TestConnectWithRetryRuns checks that the connectWithRetry function runs and runs the retries according to the times specified via environment variables
|
||||
// we will use a management server started via to simulate the server and capture the number of retries
|
||||
func TestConnectWithRetryRuns(t *testing.T) {
|
||||
// start the signal server
|
||||
_, signalAddr, err := startSignal(t)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start signal server: %v", err)
|
||||
}
|
||||
|
||||
counter := 0
|
||||
// start the management server
|
||||
_, mgmtAddr, err := startManagement(t, signalAddr, &counter)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start management server: %v", err)
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
|
||||
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second))
|
||||
defer cancel()
|
||||
// create new server
|
||||
ic := profilemanager.ConfigInput{
|
||||
ManagementURL: "http://" + mgmtAddr,
|
||||
ConfigPath: t.TempDir() + "/test-profile.json",
|
||||
}
|
||||
|
||||
config, err := profilemanager.UpdateOrCreateConfig(ic)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create config: %v", err)
|
||||
}
|
||||
|
||||
currUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: "test-profile",
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to set active profile state: %v", err)
|
||||
}
|
||||
|
||||
s := New(ctx, "debug", "", false, false, false, false)
|
||||
|
||||
s.config = config
|
||||
|
||||
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
|
||||
t.Setenv(retryInitialIntervalVar, "1s")
|
||||
t.Setenv(maxRetryIntervalVar, "2s")
|
||||
t.Setenv(maxRetryTimeVar, "5s")
|
||||
t.Setenv(retryMultiplierVar, "1")
|
||||
|
||||
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
|
||||
if counter < 3 {
|
||||
t.Fatalf("expected counter > 2, got %d", counter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_Up(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
|
||||
@@ -259,119 +157,3 @@ func TestServer_SubcribeEvents(t *testing.T) {
|
||||
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
type mockServer struct {
|
||||
mgmtProto.ManagementServiceServer
|
||||
counter *int
|
||||
}
|
||||
|
||||
func (m *mockServer) Login(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
|
||||
*m.counter++
|
||||
return m.ManagementServiceServer.Login(ctx, req)
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
|
||||
config := &config.Config{
|
||||
Stuns: []*config.Host{},
|
||||
TURNConfig: &config.TURNConfig{},
|
||||
Signal: &config.Host{
|
||||
Proto: "http",
|
||||
URI: signalAddr,
|
||||
},
|
||||
Datadir: dataDir,
|
||||
HttpConfig: nil,
|
||||
}
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
peersManager := peers.NewManager(store, permissionsManagerMock)
|
||||
settingsManagerMock := settings.NewMockManager(ctrl)
|
||||
|
||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||
|
||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mock := &mockServer{
|
||||
ManagementServiceServer: mgmtServer,
|
||||
counter: counter,
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(s, mock)
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
log.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
|
||||
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
proto.RegisterSignalExchangeServer(s, srv)
|
||||
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
198
client/server/setconfig_mdm_test.go
Normal file
198
client/server/setconfig_mdm_test.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// withMDMPolicy temporarily overrides the server-package loadMDMPolicy hook
|
||||
// so SetConfig observes the supplied Policy. Restores the original loader
|
||||
// at test cleanup.
|
||||
func withMDMPolicy(t *testing.T, policy *mdm.Policy) {
|
||||
t.Helper()
|
||||
prev := loadMDMPolicy
|
||||
loadMDMPolicy = func() *mdm.Policy { return policy }
|
||||
t.Cleanup(func() { loadMDMPolicy = prev })
|
||||
}
|
||||
|
||||
// setupServerWithProfile mirrors the boilerplate of TestSetConfig_AllFieldsSaved:
|
||||
// overrides profilemanager paths to a temp dir, seeds a profile, sets it
|
||||
// active, and constructs a Server instance. Returns the constructed server
|
||||
// plus context + profile name + username + cfgPath for the seeded profile.
|
||||
func setupServerWithProfile(t *testing.T) (s *Server, ctx context.Context, profName, username, cfgPath string) {
|
||||
t.Helper()
|
||||
tempDir := t.TempDir()
|
||||
|
||||
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
|
||||
origDefaultConfigPath := profilemanager.DefaultConfigPath
|
||||
origActiveProfileStatePath := profilemanager.ActiveProfileStatePath
|
||||
profilemanager.ConfigDirOverride = tempDir
|
||||
profilemanager.DefaultConfigPathDir = tempDir
|
||||
profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
|
||||
profilemanager.DefaultConfigPath = filepath.Join(tempDir, "default.json")
|
||||
t.Cleanup(func() {
|
||||
profilemanager.DefaultConfigPathDir = origDefaultProfileDir
|
||||
profilemanager.ActiveProfileStatePath = origActiveProfileStatePath
|
||||
profilemanager.DefaultConfigPath = origDefaultConfigPath
|
||||
profilemanager.ConfigDirOverride = ""
|
||||
})
|
||||
|
||||
currUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
profName = "test-profile-mdm"
|
||||
cfgPath = filepath.Join(tempDir, profName+".json")
|
||||
|
||||
_, err = profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
ConfigPath: cfgPath,
|
||||
ManagementURL: "https://api.netbird.io:443",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
require.NoError(t, pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: profName,
|
||||
Username: currUser.Username,
|
||||
}))
|
||||
|
||||
ctx = context.Background()
|
||||
s = New(ctx, "console", "", false, false, false, false)
|
||||
return s, ctx, profName, currUser.Username, cfgPath
|
||||
}
|
||||
|
||||
// extractViolation pulls the MDMManagedFieldsViolation detail from a
|
||||
// FailedPrecondition error. Fails the test if absent or malformed.
|
||||
func extractViolation(t *testing.T, err error) *proto.MDMManagedFieldsViolation {
|
||||
t.Helper()
|
||||
require.Error(t, err)
|
||||
st, ok := gstatus.FromError(err)
|
||||
require.True(t, ok, "error must be a gRPC status: %v", err)
|
||||
require.Equal(t, codes.FailedPrecondition, st.Code(), "expected FailedPrecondition, got %s", st.Code())
|
||||
for _, d := range st.Details() {
|
||||
if v, ok := d.(*proto.MDMManagedFieldsViolation); ok {
|
||||
return v
|
||||
}
|
||||
}
|
||||
t.Fatalf("MDMManagedFieldsViolation detail not found on status; details: %v", st.Details())
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSetConfig_MDMReject_SingleField(t *testing.T) {
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
}))
|
||||
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
|
||||
_, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
Username: username,
|
||||
ManagementUrl: "https://user.tried.this.com:443",
|
||||
})
|
||||
|
||||
v := extractViolation(t, err)
|
||||
assert.Equal(t, []string{mdm.KeyManagementURL}, v.GetFields())
|
||||
}
|
||||
|
||||
func TestSetConfig_MDMReject_MultipleFields(t *testing.T) {
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
mdm.KeyBlockInbound: true,
|
||||
mdm.KeyRosenpassEnabled: true,
|
||||
}))
|
||||
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
|
||||
blockInbound := false
|
||||
rosenpassEnabled := false
|
||||
_, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
Username: username,
|
||||
ManagementUrl: "https://user.tried.this.com:443",
|
||||
BlockInbound: &blockInbound,
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
})
|
||||
|
||||
v := extractViolation(t, err)
|
||||
assert.ElementsMatch(t, []string{
|
||||
mdm.KeyManagementURL,
|
||||
mdm.KeyBlockInbound,
|
||||
mdm.KeyRosenpassEnabled,
|
||||
}, v.GetFields())
|
||||
}
|
||||
|
||||
func TestSetConfig_MDMReject_AllOrNothing(t *testing.T) {
|
||||
// MDM enforces ManagementURL only; user request touches both the
|
||||
// enforced field AND a non-enforced field (RosenpassEnabled).
|
||||
// The whole request must be rejected — non-conflicting fields are not
|
||||
// applied either.
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
}))
|
||||
|
||||
s, ctx, profName, username, cfgPath := setupServerWithProfile(t)
|
||||
|
||||
rosenpassEnabled := true
|
||||
_, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
Username: username,
|
||||
ManagementUrl: "https://user.tried.this.com:443",
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
})
|
||||
|
||||
v := extractViolation(t, err)
|
||||
assert.Equal(t, []string{mdm.KeyManagementURL}, v.GetFields())
|
||||
|
||||
// Confirm RosenpassEnabled was NOT applied even though it was not
|
||||
// in the conflict list: the request was rejected as a whole.
|
||||
reloaded, err := profilemanager.GetConfig(cfgPath)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, reloaded.RosenpassEnabled, "non-conflicting field must not be applied when request is rejected")
|
||||
}
|
||||
|
||||
func TestSetConfig_MDMAllow_NonManagedFields(t *testing.T) {
|
||||
// MDM enforces ManagementURL but the user only writes RosenpassEnabled.
|
||||
// Request must succeed.
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
}))
|
||||
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
|
||||
rosenpassEnabled := true
|
||||
resp, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
Username: username,
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
}
|
||||
|
||||
func TestSetConfig_MDMEmpty_NoEnforcement(t *testing.T) {
|
||||
// No MDM policy active: any field can be written.
|
||||
withMDMPolicy(t, mdm.NewPolicy(nil))
|
||||
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
|
||||
resp, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
Username: username,
|
||||
ManagementUrl: "https://user.changed.url.com:443",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
}
|
||||
118
client/ssh/client/client_privileged_test.go
Normal file
118
client/ssh/client/client_privileged_test.go
Normal file
@@ -0,0 +1,118 @@
|
||||
//go:build privileged
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
)
|
||||
|
||||
func TestSSHClient_CommandExecution(t *testing.T) {
|
||||
if runtime.GOOS == "windows" && testutil.IsCI() {
|
||||
t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues")
|
||||
}
|
||||
|
||||
server, _, client := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
defer func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
t.Run("ExecuteCommand captures output", func(t *testing.T) {
|
||||
output, err := client.ExecuteCommand(ctx, "echo hello")
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(output), "hello")
|
||||
})
|
||||
|
||||
t.Run("ExecuteCommandWithIO streams output", func(t *testing.T) {
|
||||
err := client.ExecuteCommandWithIO(ctx, "echo world")
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("commands with flags work", func(t *testing.T) {
|
||||
output, err := client.ExecuteCommand(ctx, "echo -n test_flag")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test_flag", strings.TrimSpace(string(output)))
|
||||
})
|
||||
|
||||
t.Run("non-zero exit codes don't return errors", func(t *testing.T) {
|
||||
var testCmd string
|
||||
if runtime.GOOS == "windows" {
|
||||
testCmd = "echo hello | Select-String notfound"
|
||||
} else {
|
||||
testCmd = "echo 'hello' | grep 'notfound'"
|
||||
}
|
||||
_, err := client.ExecuteCommand(ctx, testCmd)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHClient_ContextCancellation(t *testing.T) {
|
||||
server, serverAddr, _ := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
t.Run("connection with short timeout", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
_, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
if err != nil {
|
||||
// Check for actual timeout-related errors rather than string matching
|
||||
assert.True(t,
|
||||
errors.Is(err, context.DeadlineExceeded) ||
|
||||
errors.Is(err, context.Canceled) ||
|
||||
strings.Contains(err.Error(), "timeout"),
|
||||
"Expected timeout-related error, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("command execution cancellation", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Logf("client close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
cmdCtx, cmdCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cmdCancel()
|
||||
|
||||
err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10")
|
||||
if err != nil {
|
||||
var exitMissingErr *cryptossh.ExitMissingError
|
||||
isValidCancellation := errors.Is(err, context.DeadlineExceeded) ||
|
||||
errors.Is(err, context.Canceled) ||
|
||||
errors.As(err, &exitMissingErr)
|
||||
assert.True(t, isValidCancellation, "Should handle command cancellation properly")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
@@ -78,53 +77,6 @@ func TestSSHClient_DialWithKey(t *testing.T) {
|
||||
assert.NotNil(t, client.client)
|
||||
}
|
||||
|
||||
func TestSSHClient_CommandExecution(t *testing.T) {
|
||||
if runtime.GOOS == "windows" && testutil.IsCI() {
|
||||
t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues")
|
||||
}
|
||||
|
||||
server, _, client := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
defer func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
t.Run("ExecuteCommand captures output", func(t *testing.T) {
|
||||
output, err := client.ExecuteCommand(ctx, "echo hello")
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(output), "hello")
|
||||
})
|
||||
|
||||
t.Run("ExecuteCommandWithIO streams output", func(t *testing.T) {
|
||||
err := client.ExecuteCommandWithIO(ctx, "echo world")
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("commands with flags work", func(t *testing.T) {
|
||||
output, err := client.ExecuteCommand(ctx, "echo -n test_flag")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test_flag", strings.TrimSpace(string(output)))
|
||||
})
|
||||
|
||||
t.Run("non-zero exit codes don't return errors", func(t *testing.T) {
|
||||
var testCmd string
|
||||
if runtime.GOOS == "windows" {
|
||||
testCmd = "echo hello | Select-String notfound"
|
||||
} else {
|
||||
testCmd = "echo 'hello' | grep 'notfound'"
|
||||
}
|
||||
_, err := client.ExecuteCommand(ctx, testCmd)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHClient_ConnectionHandling(t *testing.T) {
|
||||
server, serverAddr, _ := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
@@ -154,59 +106,6 @@ func TestSSHClient_ConnectionHandling(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHClient_ContextCancellation(t *testing.T) {
|
||||
server, serverAddr, _ := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
t.Run("connection with short timeout", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
_, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
if err != nil {
|
||||
// Check for actual timeout-related errors rather than string matching
|
||||
assert.True(t,
|
||||
errors.Is(err, context.DeadlineExceeded) ||
|
||||
errors.Is(err, context.Canceled) ||
|
||||
strings.Contains(err.Error(), "timeout"),
|
||||
"Expected timeout-related error, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("command execution cancellation", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Logf("client close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
cmdCtx, cmdCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cmdCancel()
|
||||
|
||||
err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10")
|
||||
if err != nil {
|
||||
var exitMissingErr *cryptossh.ExitMissingError
|
||||
isValidCancellation := errors.Is(err, context.DeadlineExceeded) ||
|
||||
errors.Is(err, context.Canceled) ||
|
||||
errors.As(err, &exitMissingErr)
|
||||
assert.True(t, isValidCancellation, "Should handle command cancellation properly")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHClient_NoAuthMode(t *testing.T) {
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
423
client/ssh/proxy/proxy_privileged_test.go
Normal file
423
client/ssh/proxy/proxy_privileged_test.go
Normal file
@@ -0,0 +1,423 @@
|
||||
//go:build privileged
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
"github.com/netbirdio/netbird/client/ssh/server"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
func (m *mockDaemon) setJWTToken(token string) {
|
||||
m.impl.jwtToken = token
|
||||
}
|
||||
|
||||
func TestSSHProxy_Connect(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
// TODO: Windows test times out - user switching and command execution tested on Linux
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Skipping on Windows - covered by Linux tests")
|
||||
}
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
defer jwksServer.Close()
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &server.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: &server.JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audiences: []string{audience},
|
||||
KeysLocation: jwksURL,
|
||||
},
|
||||
}
|
||||
sshServer := server.New(serverConfig)
|
||||
sshServer.SetAllowRootLogin(true)
|
||||
|
||||
// Configure SSH authorization for the test user
|
||||
testUsername := testutil.GetTestUsername(t)
|
||||
testJWTUser := "test-username"
|
||||
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
authConfig := &sshauth.Config{
|
||||
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
testUsername: {0}, // Index 0 in AuthorizedUsers
|
||||
},
|
||||
}
|
||||
sshServer.UpdateSSHAuth(authConfig)
|
||||
|
||||
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||
defer func() { _ = sshServer.Stop() }()
|
||||
|
||||
mockDaemon := startMockDaemon(t)
|
||||
defer mockDaemon.stop()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockDaemon.setHostKey(host, hostPubKey)
|
||||
|
||||
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
||||
mockDaemon.setJWTToken(validToken)
|
||||
|
||||
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientConn, proxyConn := net.Pipe()
|
||||
defer func() { _ = clientConn.Close() }()
|
||||
|
||||
origStdin := os.Stdin
|
||||
origStdout := os.Stdout
|
||||
defer func() {
|
||||
os.Stdin = origStdin
|
||||
os.Stdout = origStdout
|
||||
}()
|
||||
|
||||
stdinReader, stdinWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
stdoutReader, stdoutWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
os.Stdin = stdinReader
|
||||
os.Stdout = stdoutWriter
|
||||
|
||||
go func() {
|
||||
_, _ = io.Copy(stdinWriter, proxyConn)
|
||||
}()
|
||||
go func() {
|
||||
_, _ = io.Copy(proxyConn, stdoutReader)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
connectErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
connectErrCh <- proxyInstance.Connect(ctx)
|
||||
}()
|
||||
|
||||
sshConfig := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
||||
require.NoError(t, err, "Should connect to proxy server")
|
||||
defer func() { _ = sshClientConn.Close() }()
|
||||
|
||||
sshClient := cryptossh.NewClient(sshClientConn, chans, reqs)
|
||||
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err, "Should create session through full proxy to backend")
|
||||
|
||||
outputCh := make(chan []byte, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
output, err := session.Output("echo hello-from-proxy")
|
||||
outputCh <- output
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case output := <-outputCh:
|
||||
err := <-errCh
|
||||
require.NoError(t, err, "Command should execute successfully through proxy")
|
||||
assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy")
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("Command execution timed out")
|
||||
}
|
||||
|
||||
_ = session.Close()
|
||||
_ = sshClient.Close()
|
||||
_ = clientConn.Close()
|
||||
cancel()
|
||||
}
|
||||
|
||||
// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
|
||||
// when forwarding commands to the backend. This is critical for tools like
|
||||
// Ansible that send commands such as:
|
||||
//
|
||||
// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
|
||||
//
|
||||
// The single quotes must be preserved so the backend shell receives the
|
||||
// subshell expression as a single argument to -c.
|
||||
func TestSSHProxy_CommandQuoting(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
sshClient, cleanup := setupProxySSHClient(t)
|
||||
defer cleanup()
|
||||
|
||||
// These commands simulate what the SSH protocol delivers as exec payloads.
|
||||
// When a user types: ssh host '/bin/sh -c "( echo hello )"'
|
||||
// the local shell strips the outer single quotes, and the SSH exec request
|
||||
// contains the raw string: /bin/sh -c "( echo hello )"
|
||||
//
|
||||
// The proxy must forward this string verbatim. Using session.Command()
|
||||
// (shlex.Split + strings.Join) strips the inner double quotes, breaking
|
||||
// the command on the backend.
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "subshell_in_double_quotes",
|
||||
command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
|
||||
expect: "from-subshell\nouter\n",
|
||||
},
|
||||
{
|
||||
name: "printf_with_special_chars",
|
||||
command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
|
||||
expect: "hello world\n",
|
||||
},
|
||||
{
|
||||
name: "nested_command_substitution",
|
||||
command: `/bin/sh -c "echo $(echo nested)"`,
|
||||
expect: "nested\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = session.Close() }()
|
||||
|
||||
var stderrBuf bytes.Buffer
|
||||
session.Stderr = &stderrBuf
|
||||
|
||||
outputCh := make(chan []byte, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
output, err := session.Output(tc.command)
|
||||
outputCh <- output
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case output := <-outputCh:
|
||||
err := <-errCh
|
||||
if stderrBuf.Len() > 0 {
|
||||
t.Logf("stderr: %s", stderrBuf.String())
|
||||
}
|
||||
require.NoError(t, err, "command should succeed: %s", tc.command)
|
||||
assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatalf("command timed out: %s", tc.command)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// setupProxySSHClient creates a full proxy test environment and returns
|
||||
// an SSH client connected through the proxy to a backend NetBird SSH server.
|
||||
func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
|
||||
t.Helper()
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &server.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: &server.JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audiences: []string{audience},
|
||||
KeysLocation: jwksURL,
|
||||
},
|
||||
}
|
||||
sshServer := server.New(serverConfig)
|
||||
sshServer.SetAllowRootLogin(true)
|
||||
|
||||
testUsername := testutil.GetTestUsername(t)
|
||||
testJWTUser := "test-username"
|
||||
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
authConfig := &sshauth.Config{
|
||||
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
testUsername: {0},
|
||||
},
|
||||
}
|
||||
sshServer.UpdateSSHAuth(authConfig)
|
||||
|
||||
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||
|
||||
mockDaemon := startMockDaemon(t)
|
||||
|
||||
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockDaemon.setHostKey(host, hostPubKey)
|
||||
|
||||
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
||||
mockDaemon.setJWTToken(validToken)
|
||||
|
||||
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
origStdin := os.Stdin
|
||||
origStdout := os.Stdout
|
||||
|
||||
stdinReader, stdinWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
stdoutReader, stdoutWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
os.Stdin = stdinReader
|
||||
os.Stdout = stdoutWriter
|
||||
|
||||
clientConn, proxyConn := net.Pipe()
|
||||
|
||||
go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
|
||||
go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
|
||||
go func() {
|
||||
_ = proxyInstance.Connect(ctx)
|
||||
}()
|
||||
|
||||
sshConfig := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := cryptossh.NewClient(sshClientConn, chans, reqs)
|
||||
|
||||
cleanupFn := func() {
|
||||
_ = client.Close()
|
||||
_ = clientConn.Close()
|
||||
cancel()
|
||||
os.Stdin = origStdin
|
||||
os.Stdout = origStdout
|
||||
_ = sshServer.Stop()
|
||||
mockDaemon.stop()
|
||||
jwksServer.Close()
|
||||
}
|
||||
|
||||
return client, cleanupFn
|
||||
}
|
||||
|
||||
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
|
||||
t.Helper()
|
||||
privateKey, jwksJSON := generateTestJWKS(t)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if _, err := w.Write(jwksJSON); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
|
||||
return server, privateKey, server.URL
|
||||
}
|
||||
|
||||
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
|
||||
t.Helper()
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKey := &privateKey.PublicKey
|
||||
n := publicKey.N.Bytes()
|
||||
e := publicKey.E
|
||||
|
||||
jwk := nbjwt.JSONWebKey{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Use: "sig",
|
||||
N: base64.RawURLEncoding.EncodeToString(n),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()),
|
||||
}
|
||||
|
||||
jwks := nbjwt.Jwks{
|
||||
Keys: []nbjwt.JSONWebKey{jwk},
|
||||
}
|
||||
|
||||
jwksJSON, err := json.Marshal(jwks)
|
||||
require.NoError(t, err)
|
||||
|
||||
return privateKey, jwksJSON
|
||||
}
|
||||
|
||||
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string {
|
||||
t.Helper()
|
||||
claims := jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": user,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
token.Header["kid"] = "test-key-id"
|
||||
|
||||
tokenString, err := token.SignedString(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
return tokenString
|
||||
}
|
||||
@@ -1,25 +1,12 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
@@ -28,11 +15,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
"github.com/netbirdio/netbird/client/ssh/server"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
@@ -106,331 +89,6 @@ func TestSSHProxy_verifyHostKey(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHProxy_Connect(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
// TODO: Windows test times out - user switching and command execution tested on Linux
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Skipping on Windows - covered by Linux tests")
|
||||
}
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
defer jwksServer.Close()
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &server.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: &server.JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audiences: []string{audience},
|
||||
KeysLocation: jwksURL,
|
||||
},
|
||||
}
|
||||
sshServer := server.New(serverConfig)
|
||||
sshServer.SetAllowRootLogin(true)
|
||||
|
||||
// Configure SSH authorization for the test user
|
||||
testUsername := testutil.GetTestUsername(t)
|
||||
testJWTUser := "test-username"
|
||||
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
authConfig := &sshauth.Config{
|
||||
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
testUsername: {0}, // Index 0 in AuthorizedUsers
|
||||
},
|
||||
}
|
||||
sshServer.UpdateSSHAuth(authConfig)
|
||||
|
||||
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||
defer func() { _ = sshServer.Stop() }()
|
||||
|
||||
mockDaemon := startMockDaemon(t)
|
||||
defer mockDaemon.stop()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockDaemon.setHostKey(host, hostPubKey)
|
||||
|
||||
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
||||
mockDaemon.setJWTToken(validToken)
|
||||
|
||||
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientConn, proxyConn := net.Pipe()
|
||||
defer func() { _ = clientConn.Close() }()
|
||||
|
||||
origStdin := os.Stdin
|
||||
origStdout := os.Stdout
|
||||
defer func() {
|
||||
os.Stdin = origStdin
|
||||
os.Stdout = origStdout
|
||||
}()
|
||||
|
||||
stdinReader, stdinWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
stdoutReader, stdoutWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
os.Stdin = stdinReader
|
||||
os.Stdout = stdoutWriter
|
||||
|
||||
go func() {
|
||||
_, _ = io.Copy(stdinWriter, proxyConn)
|
||||
}()
|
||||
go func() {
|
||||
_, _ = io.Copy(proxyConn, stdoutReader)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
connectErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
connectErrCh <- proxyInstance.Connect(ctx)
|
||||
}()
|
||||
|
||||
sshConfig := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
||||
require.NoError(t, err, "Should connect to proxy server")
|
||||
defer func() { _ = sshClientConn.Close() }()
|
||||
|
||||
sshClient := cryptossh.NewClient(sshClientConn, chans, reqs)
|
||||
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err, "Should create session through full proxy to backend")
|
||||
|
||||
outputCh := make(chan []byte, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
output, err := session.Output("echo hello-from-proxy")
|
||||
outputCh <- output
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case output := <-outputCh:
|
||||
err := <-errCh
|
||||
require.NoError(t, err, "Command should execute successfully through proxy")
|
||||
assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy")
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("Command execution timed out")
|
||||
}
|
||||
|
||||
_ = session.Close()
|
||||
_ = sshClient.Close()
|
||||
_ = clientConn.Close()
|
||||
cancel()
|
||||
}
|
||||
|
||||
// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
|
||||
// when forwarding commands to the backend. This is critical for tools like
|
||||
// Ansible that send commands such as:
|
||||
//
|
||||
// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
|
||||
//
|
||||
// The single quotes must be preserved so the backend shell receives the
|
||||
// subshell expression as a single argument to -c.
|
||||
func TestSSHProxy_CommandQuoting(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
sshClient, cleanup := setupProxySSHClient(t)
|
||||
defer cleanup()
|
||||
|
||||
// These commands simulate what the SSH protocol delivers as exec payloads.
|
||||
// When a user types: ssh host '/bin/sh -c "( echo hello )"'
|
||||
// the local shell strips the outer single quotes, and the SSH exec request
|
||||
// contains the raw string: /bin/sh -c "( echo hello )"
|
||||
//
|
||||
// The proxy must forward this string verbatim. Using session.Command()
|
||||
// (shlex.Split + strings.Join) strips the inner double quotes, breaking
|
||||
// the command on the backend.
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "subshell_in_double_quotes",
|
||||
command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
|
||||
expect: "from-subshell\nouter\n",
|
||||
},
|
||||
{
|
||||
name: "printf_with_special_chars",
|
||||
command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
|
||||
expect: "hello world\n",
|
||||
},
|
||||
{
|
||||
name: "nested_command_substitution",
|
||||
command: `/bin/sh -c "echo $(echo nested)"`,
|
||||
expect: "nested\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = session.Close() }()
|
||||
|
||||
var stderrBuf bytes.Buffer
|
||||
session.Stderr = &stderrBuf
|
||||
|
||||
outputCh := make(chan []byte, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
output, err := session.Output(tc.command)
|
||||
outputCh <- output
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case output := <-outputCh:
|
||||
err := <-errCh
|
||||
if stderrBuf.Len() > 0 {
|
||||
t.Logf("stderr: %s", stderrBuf.String())
|
||||
}
|
||||
require.NoError(t, err, "command should succeed: %s", tc.command)
|
||||
assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatalf("command timed out: %s", tc.command)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// setupProxySSHClient creates a full proxy test environment and returns
|
||||
// an SSH client connected through the proxy to a backend NetBird SSH server.
|
||||
func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
|
||||
t.Helper()
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &server.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: &server.JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audiences: []string{audience},
|
||||
KeysLocation: jwksURL,
|
||||
},
|
||||
}
|
||||
sshServer := server.New(serverConfig)
|
||||
sshServer.SetAllowRootLogin(true)
|
||||
|
||||
testUsername := testutil.GetTestUsername(t)
|
||||
testJWTUser := "test-username"
|
||||
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
authConfig := &sshauth.Config{
|
||||
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
testUsername: {0},
|
||||
},
|
||||
}
|
||||
sshServer.UpdateSSHAuth(authConfig)
|
||||
|
||||
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||
|
||||
mockDaemon := startMockDaemon(t)
|
||||
|
||||
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockDaemon.setHostKey(host, hostPubKey)
|
||||
|
||||
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
||||
mockDaemon.setJWTToken(validToken)
|
||||
|
||||
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
origStdin := os.Stdin
|
||||
origStdout := os.Stdout
|
||||
|
||||
stdinReader, stdinWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
stdoutReader, stdoutWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
os.Stdin = stdinReader
|
||||
os.Stdout = stdoutWriter
|
||||
|
||||
clientConn, proxyConn := net.Pipe()
|
||||
|
||||
go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
|
||||
go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
|
||||
go func() {
|
||||
_ = proxyInstance.Connect(ctx)
|
||||
}()
|
||||
|
||||
sshConfig := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := cryptossh.NewClient(sshClientConn, chans, reqs)
|
||||
|
||||
cleanupFn := func() {
|
||||
_ = client.Close()
|
||||
_ = clientConn.Close()
|
||||
cancel()
|
||||
os.Stdin = origStdin
|
||||
os.Stdout = origStdout
|
||||
_ = sshServer.Stop()
|
||||
mockDaemon.stop()
|
||||
jwksServer.Close()
|
||||
}
|
||||
|
||||
return client, cleanupFn
|
||||
}
|
||||
|
||||
type mockDaemonServer struct {
|
||||
proto.UnimplementedDaemonServiceServer
|
||||
hostKeys map[string][]byte
|
||||
@@ -492,10 +150,6 @@ func (m *mockDaemon) setHostKey(addr string, pubKey []byte) {
|
||||
m.impl.hostKeys[addr] = pubKey
|
||||
}
|
||||
|
||||
func (m *mockDaemon) setJWTToken(token string) {
|
||||
m.impl.jwtToken = token
|
||||
}
|
||||
|
||||
func (m *mockDaemon) stop() {
|
||||
if m.server != nil {
|
||||
m.server.Stop()
|
||||
@@ -508,63 +162,3 @@ func mustParsePublicKey(t *testing.T, pubKeyBytes []byte) cryptossh.PublicKey {
|
||||
require.NoError(t, err)
|
||||
return pubKey
|
||||
}
|
||||
|
||||
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
|
||||
t.Helper()
|
||||
privateKey, jwksJSON := generateTestJWKS(t)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if _, err := w.Write(jwksJSON); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
|
||||
return server, privateKey, server.URL
|
||||
}
|
||||
|
||||
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
|
||||
t.Helper()
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKey := &privateKey.PublicKey
|
||||
n := publicKey.N.Bytes()
|
||||
e := publicKey.E
|
||||
|
||||
jwk := nbjwt.JSONWebKey{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Use: "sig",
|
||||
N: base64.RawURLEncoding.EncodeToString(n),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()),
|
||||
}
|
||||
|
||||
jwks := nbjwt.Jwks{
|
||||
Keys: []nbjwt.JSONWebKey{jwk},
|
||||
}
|
||||
|
||||
jwksJSON, err := json.Marshal(jwks)
|
||||
require.NoError(t, err)
|
||||
|
||||
return privateKey, jwksJSON
|
||||
}
|
||||
|
||||
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string {
|
||||
t.Helper()
|
||||
claims := jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": user,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
token.Header["kid"] = "test-key-id"
|
||||
|
||||
tokenString, err := token.SignedString(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
return tokenString
|
||||
}
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
|
||||
"github.com/creack/pty"
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/netbirdio/netbird/client/internal/shell"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -147,10 +146,10 @@ func (s *Server) createShellCommand(ctx context.Context, shell string, args []st
|
||||
|
||||
// prepareCommandEnv prepares environment variables for command execution on Unix
|
||||
func (s *Server) prepareCommandEnv(_ *log.Entry, localUser *user.User, session ssh.Session) []string {
|
||||
env := shell.PrepareUserEnv(localUser, shell.GetUserShell(localUser.Uid))
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
if shell.AcceptEnv(v) {
|
||||
if acceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -247,10 +247,10 @@ func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, sess
|
||||
userEnv, err := s.getUserEnvironment(logger, username, domain)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
|
||||
env := shell.PrepareUserEnv(localUser, shell.GetUserShell(localUser.Uid))
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
if shell.AcceptEnv(v) {
|
||||
if acceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
@@ -260,7 +260,7 @@ func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, sess
|
||||
env := userEnv
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
if shell.AcceptEnv(v) {
|
||||
if acceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
@@ -273,7 +273,7 @@ func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privileg
|
||||
return false
|
||||
}
|
||||
|
||||
shell := shell.GetUserShell(privilegeResult.User.Uid)
|
||||
shell := getUserShell(privilegeResult.User.Uid)
|
||||
logger.Infof("starting interactive shell: %s", shell)
|
||||
|
||||
s.executeCommandWithPty(logger, session, nil, privilegeResult, ptyReq, nil)
|
||||
@@ -384,7 +384,7 @@ func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, _
|
||||
}
|
||||
|
||||
username, domain := s.parseUsername(localUser.Username)
|
||||
shell := shell.GetUserShell(localUser.Uid)
|
||||
shell := getUserShell(localUser.Uid)
|
||||
|
||||
req := PtyExecutionRequest{
|
||||
Shell: shell,
|
||||
|
||||
66
client/ssh/server/executor_unix_privileged_test.go
Normal file
66
client/ssh/server/executor_unix_privileged_test.go
Normal file
@@ -0,0 +1,66 @@
|
||||
//go:build unix && privileged
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) {
|
||||
pd := NewPrivilegeDropper()
|
||||
|
||||
config := ExecutorConfig{
|
||||
UID: 1000,
|
||||
GID: 1000,
|
||||
Groups: []uint32{1000, 1001},
|
||||
WorkingDir: "/home/testuser",
|
||||
Shell: "/bin/bash",
|
||||
Command: "ls -la",
|
||||
}
|
||||
|
||||
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cmd)
|
||||
|
||||
// Verify the command is calling netbird ssh exec
|
||||
assert.Contains(t, cmd.Args, "ssh")
|
||||
assert.Contains(t, cmd.Args, "exec")
|
||||
assert.Contains(t, cmd.Args, "--uid")
|
||||
assert.Contains(t, cmd.Args, "1000")
|
||||
assert.Contains(t, cmd.Args, "--gid")
|
||||
assert.Contains(t, cmd.Args, "1000")
|
||||
assert.Contains(t, cmd.Args, "--groups")
|
||||
assert.Contains(t, cmd.Args, "1000")
|
||||
assert.Contains(t, cmd.Args, "1001")
|
||||
assert.Contains(t, cmd.Args, "--working-dir")
|
||||
assert.Contains(t, cmd.Args, "/home/testuser")
|
||||
assert.Contains(t, cmd.Args, "--shell")
|
||||
assert.Contains(t, cmd.Args, "/bin/bash")
|
||||
assert.Contains(t, cmd.Args, "--cmd")
|
||||
assert.Contains(t, cmd.Args, "ls -la")
|
||||
}
|
||||
|
||||
func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) {
|
||||
pd := NewPrivilegeDropper()
|
||||
|
||||
config := ExecutorConfig{
|
||||
UID: 1000,
|
||||
GID: 1000,
|
||||
Groups: []uint32{1000},
|
||||
WorkingDir: "/home/testuser",
|
||||
Shell: "/bin/bash",
|
||||
Command: "",
|
||||
}
|
||||
|
||||
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cmd)
|
||||
|
||||
// Verify no command mode (command is empty so no --cmd flag)
|
||||
assert.NotContains(t, cmd.Args, "--cmd")
|
||||
assert.NotContains(t, cmd.Args, "--interactive")
|
||||
}
|
||||
@@ -73,61 +73,6 @@ func TestPrivilegeDropper_ValidatePrivileges(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) {
|
||||
pd := NewPrivilegeDropper()
|
||||
|
||||
config := ExecutorConfig{
|
||||
UID: 1000,
|
||||
GID: 1000,
|
||||
Groups: []uint32{1000, 1001},
|
||||
WorkingDir: "/home/testuser",
|
||||
Shell: "/bin/bash",
|
||||
Command: "ls -la",
|
||||
}
|
||||
|
||||
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cmd)
|
||||
|
||||
// Verify the command is calling netbird ssh exec
|
||||
assert.Contains(t, cmd.Args, "ssh")
|
||||
assert.Contains(t, cmd.Args, "exec")
|
||||
assert.Contains(t, cmd.Args, "--uid")
|
||||
assert.Contains(t, cmd.Args, "1000")
|
||||
assert.Contains(t, cmd.Args, "--gid")
|
||||
assert.Contains(t, cmd.Args, "1000")
|
||||
assert.Contains(t, cmd.Args, "--groups")
|
||||
assert.Contains(t, cmd.Args, "1000")
|
||||
assert.Contains(t, cmd.Args, "1001")
|
||||
assert.Contains(t, cmd.Args, "--working-dir")
|
||||
assert.Contains(t, cmd.Args, "/home/testuser")
|
||||
assert.Contains(t, cmd.Args, "--shell")
|
||||
assert.Contains(t, cmd.Args, "/bin/bash")
|
||||
assert.Contains(t, cmd.Args, "--cmd")
|
||||
assert.Contains(t, cmd.Args, "ls -la")
|
||||
}
|
||||
|
||||
func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) {
|
||||
pd := NewPrivilegeDropper()
|
||||
|
||||
config := ExecutorConfig{
|
||||
UID: 1000,
|
||||
GID: 1000,
|
||||
Groups: []uint32{1000},
|
||||
WorkingDir: "/home/testuser",
|
||||
Shell: "/bin/bash",
|
||||
Command: "",
|
||||
}
|
||||
|
||||
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cmd)
|
||||
|
||||
// Verify no command mode (command is empty so no --cmd flag)
|
||||
assert.NotContains(t, cmd.Args, "--cmd")
|
||||
assert.NotContains(t, cmd.Args, "--interactive")
|
||||
}
|
||||
|
||||
// TestPrivilegeDropper_ActualPrivilegeDrop tests actual privilege dropping
|
||||
// This test requires root privileges and will be skipped if not running as root
|
||||
func TestPrivilegeDropper_ActualPrivilegeDrop(t *testing.T) {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build cgo && !osusergo && !windows
|
||||
|
||||
package shell
|
||||
package server
|
||||
|
||||
import "os/user"
|
||||
|
||||
@@ -8,22 +8,17 @@ import "os/user"
|
||||
// When CGO is enabled, os/user uses libc (getpwnam_r) which goes through
|
||||
// the NSS stack natively. If it fails, the user truly doesn't exist and
|
||||
// getent would also fail.
|
||||
func LookupWithGetent(username string) (*user.User, error) {
|
||||
func lookupWithGetent(username string) (*user.User, error) {
|
||||
return user.Lookup(username)
|
||||
}
|
||||
|
||||
// currentUserWithGetent with CGO delegates directly to os/user.Current.
|
||||
func CurrentUserWithGetent() (*user.User, error) {
|
||||
func currentUserWithGetent() (*user.User, error) {
|
||||
return user.Current()
|
||||
}
|
||||
|
||||
// LookupGroupWithGetent returns the resolved group from either a gid or groupname
|
||||
func LookupGroupWithGetent(name string) (*user.Group, error) {
|
||||
return user.LookupGroup(name)
|
||||
}
|
||||
|
||||
// groupIdsWithFallback with CGO delegates directly to user.GroupIds.
|
||||
// libc's getgrouplist handles NSS groups natively.
|
||||
func GroupIdsWithFallback(u *user.User) ([]string, error) {
|
||||
func groupIdsWithFallback(u *user.User) ([]string, error) {
|
||||
return u.GroupIds()
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build (!cgo || osusergo) && !windows
|
||||
|
||||
package shell
|
||||
package server
|
||||
|
||||
import (
|
||||
"os"
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
// lookupWithGetent looks up a user by name, falling back to getent if os/user fails.
|
||||
// Without CGO, os/user only reads /etc/passwd and misses NSS-provided users.
|
||||
// getent goes through the host's NSS stack.
|
||||
func LookupWithGetent(username string) (*user.User, error) {
|
||||
func lookupWithGetent(username string) (*user.User, error) {
|
||||
u, err := user.Lookup(username)
|
||||
if err == nil {
|
||||
return u, nil
|
||||
@@ -22,7 +22,7 @@ func LookupWithGetent(username string) (*user.User, error) {
|
||||
stdErr := err
|
||||
log.Debugf("os/user.Lookup(%q) failed, trying getent: %v", username, err)
|
||||
|
||||
u, _, getentErr := runGetentPasswd(username)
|
||||
u, _, getentErr := runGetent(username)
|
||||
if getentErr != nil {
|
||||
log.Debugf("getent fallback for %q also failed: %v", username, getentErr)
|
||||
return nil, stdErr
|
||||
@@ -31,25 +31,8 @@ func LookupWithGetent(username string) (*user.User, error) {
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// LookupGroupWithGetent returns the resolved group from either a gid or groupname
|
||||
func LookupGroupWithGetent(name string) (*user.Group, error) {
|
||||
g, err := user.LookupGroup(name)
|
||||
if err == nil {
|
||||
return g, nil
|
||||
}
|
||||
|
||||
stdErr := err
|
||||
log.Debugf("os/user.LookupGroup(%q) failed, trying getent: %v", name, err)
|
||||
g, getentErr := runGetentGroup(name)
|
||||
if getentErr != nil {
|
||||
log.Debugf("getent fallback for %q also failed: %v", name, getentErr)
|
||||
return nil, stdErr
|
||||
}
|
||||
return g, nil
|
||||
}
|
||||
|
||||
// currentUserWithGetent gets the current user, falling back to getent if os/user fails.
|
||||
func CurrentUserWithGetent() (*user.User, error) {
|
||||
func currentUserWithGetent() (*user.User, error) {
|
||||
u, err := user.Current()
|
||||
if err == nil {
|
||||
return u, nil
|
||||
@@ -59,7 +42,7 @@ func CurrentUserWithGetent() (*user.User, error) {
|
||||
uid := strconv.Itoa(os.Getuid())
|
||||
log.Debugf("os/user.Current() failed, trying getent with UID %s: %v", uid, err)
|
||||
|
||||
u, _, getentErr := runGetentPasswd(uid)
|
||||
u, _, getentErr := runGetent(uid)
|
||||
if getentErr != nil {
|
||||
return nil, stdErr
|
||||
}
|
||||
@@ -74,7 +57,7 @@ func CurrentUserWithGetent() (*user.User, error) {
|
||||
// only reads /etc/group and silently returns incomplete results for NSS users
|
||||
// (no error, just missing groups). The id command goes through NSS and returns
|
||||
// the full set.
|
||||
func GroupIdsWithFallback(u *user.User) ([]string, error) {
|
||||
func groupIdsWithFallback(u *user.User) ([]string, error) {
|
||||
ids, err := runIdGroups(u.Username)
|
||||
if err == nil {
|
||||
return ids, nil
|
||||
@@ -1,4 +1,4 @@
|
||||
package shell
|
||||
package server
|
||||
|
||||
import (
|
||||
"os/user"
|
||||
@@ -15,7 +15,7 @@ func TestLookupWithGetent_CurrentUser(t *testing.T) {
|
||||
current, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
u, err := LookupWithGetent(current.Username)
|
||||
u, err := lookupWithGetent(current.Username)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, current.Username, u.Username)
|
||||
assert.Equal(t, current.Uid, u.Uid)
|
||||
@@ -23,7 +23,7 @@ func TestLookupWithGetent_CurrentUser(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLookupWithGetent_NonexistentUser(t *testing.T) {
|
||||
_, err := LookupWithGetent("nonexistent_user_xyzzy_12345")
|
||||
_, err := lookupWithGetent("nonexistent_user_xyzzy_12345")
|
||||
require.Error(t, err, "should fail for nonexistent user")
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ func TestCurrentUserWithGetent(t *testing.T) {
|
||||
stdUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
u, err := CurrentUserWithGetent()
|
||||
u, err := currentUserWithGetent()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, stdUser.Uid, u.Uid)
|
||||
assert.Equal(t, stdUser.Username, u.Username)
|
||||
@@ -41,7 +41,7 @@ func TestGroupIdsWithFallback_CurrentUser(t *testing.T) {
|
||||
current, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, err := GroupIdsWithFallback(current)
|
||||
groups, err := groupIdsWithFallback(current)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, groups, "current user should have at least one group")
|
||||
|
||||
@@ -56,7 +56,7 @@ func TestGroupIdsWithFallback_CurrentUser(t *testing.T) {
|
||||
func TestGetShellFromGetent_CurrentUser(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
// Windows stub always returns empty, which is correct
|
||||
shell := GetShellFromGetent("1000")
|
||||
shell := getShellFromGetent("1000")
|
||||
assert.Empty(t, shell, "Windows stub should return empty")
|
||||
return
|
||||
}
|
||||
@@ -65,7 +65,7 @@ func TestGetShellFromGetent_CurrentUser(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// getent may not be available on all systems (e.g., macOS without Homebrew getent)
|
||||
shell := GetShellFromGetent(current.Uid)
|
||||
shell := getShellFromGetent(current.Uid)
|
||||
if shell == "" {
|
||||
t.Log("getShellFromGetent returned empty, getent may not be available")
|
||||
return
|
||||
@@ -78,7 +78,7 @@ func TestLookupWithGetent_RootUser(t *testing.T) {
|
||||
t.Skip("no root user on Windows")
|
||||
}
|
||||
|
||||
u, err := LookupWithGetent("root")
|
||||
u, err := lookupWithGetent("root")
|
||||
if err != nil {
|
||||
t.Skip("root user not available on this system")
|
||||
}
|
||||
@@ -91,20 +91,20 @@ func TestLookupWithGetent_RootUser(t *testing.T) {
|
||||
// consistent and correct results when composed together.
|
||||
func TestIntegration_FullLookupChain(t *testing.T) {
|
||||
// Step 1: currentUserWithGetent must resolve the running user.
|
||||
current, err := CurrentUserWithGetent()
|
||||
current, err := currentUserWithGetent()
|
||||
require.NoError(t, err, "currentUserWithGetent must resolve the running user")
|
||||
require.NotEmpty(t, current.Uid)
|
||||
require.NotEmpty(t, current.Username)
|
||||
|
||||
// Step 2: lookupWithGetent by the same username must return matching identity.
|
||||
byName, err := LookupWithGetent(current.Username)
|
||||
byName, err := lookupWithGetent(current.Username)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, current.Uid, byName.Uid, "lookup by name should return same UID")
|
||||
assert.Equal(t, current.Gid, byName.Gid, "lookup by name should return same GID")
|
||||
assert.Equal(t, current.HomeDir, byName.HomeDir, "lookup by name should return same home")
|
||||
|
||||
// Step 3: groupIdsWithFallback must return at least the primary GID.
|
||||
groups, err := GroupIdsWithFallback(current)
|
||||
groups, err := groupIdsWithFallback(current)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, groups, "user must have at least one group")
|
||||
|
||||
@@ -123,7 +123,7 @@ func TestIntegration_FullLookupChain(t *testing.T) {
|
||||
// Step 4: getShellFromGetent should either return a valid shell path or empty
|
||||
// (empty is OK when getent is not available, e.g. macOS without Homebrew getent).
|
||||
if runtime.GOOS != "windows" {
|
||||
shell := GetShellFromGetent(current.Uid)
|
||||
shell := getShellFromGetent(current.Uid)
|
||||
if shell != "" {
|
||||
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
|
||||
}
|
||||
@@ -138,10 +138,10 @@ func TestIntegration_LookupAndGroupsConsistency(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate the SSH server flow: lookup user, then get their groups.
|
||||
resolved, err := LookupWithGetent(current.Username)
|
||||
resolved, err := lookupWithGetent(current.Username)
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, err := GroupIdsWithFallback(resolved)
|
||||
groups, err := groupIdsWithFallback(resolved)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, groups, "resolved user must have groups")
|
||||
|
||||
@@ -154,3 +154,19 @@ func TestIntegration_LookupAndGroupsConsistency(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration_ShellLookupChain tests the full shell resolution chain
|
||||
// (getShellFromPasswd -> getShellFromGetent -> $SHELL -> default) on Unix.
|
||||
func TestIntegration_ShellLookupChain(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Unix shell lookup not applicable on Windows")
|
||||
}
|
||||
|
||||
current, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
// getUserShell is the top-level function used by the SSH server.
|
||||
shell := getUserShell(current.Uid)
|
||||
require.NotEmpty(t, shell, "getUserShell must always return a shell")
|
||||
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build !windows
|
||||
|
||||
package shell
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -14,25 +14,19 @@ import (
|
||||
|
||||
const getentTimeout = 5 * time.Second
|
||||
|
||||
// GetShellFromGetent gets a user's login shell via getent by UID.
|
||||
// getShellFromGetent gets a user's login shell via getent by UID.
|
||||
// This is needed even with CGO because getShellFromPasswd reads /etc/passwd
|
||||
// directly and won't find NSS-provided users there.
|
||||
func GetShellFromGetent(userID string) string {
|
||||
_, shell, err := runGetentPasswd(userID)
|
||||
func getShellFromGetent(userID string) string {
|
||||
_, shell, err := runGetent(userID)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return shell
|
||||
}
|
||||
|
||||
// GetUserFromGetent returns the resolved group from either a uid or username
|
||||
func GetUserFromGetent(user string) (*user.User, error) {
|
||||
u, _, err := runGetentPasswd(user)
|
||||
return u, err
|
||||
}
|
||||
|
||||
// runGetentPasswd executes `getent passwd <query>` and returns the user and login shell.
|
||||
func runGetentPasswd(query string) (*user.User, string, error) {
|
||||
// runGetent executes `getent passwd <query>` and returns the user and login shell.
|
||||
func runGetent(query string) (*user.User, string, error) {
|
||||
if !validateGetentInput(query) {
|
||||
return nil, "", fmt.Errorf("invalid getent input: %q", query)
|
||||
}
|
||||
@@ -48,23 +42,6 @@ func runGetentPasswd(query string) (*user.User, string, error) {
|
||||
return parseGetentPasswd(string(out))
|
||||
}
|
||||
|
||||
// runGetentGroup executes `getent group <query>` and returns the group
|
||||
func runGetentGroup(query string) (*user.Group, error) {
|
||||
if !validateGetentInput(query) {
|
||||
return nil, fmt.Errorf("invalid getent input: %q", query)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), getentTimeout)
|
||||
defer cancel()
|
||||
|
||||
out, err := exec.CommandContext(ctx, "getent", "group", query).Output()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getent passwd%s: %w", query, err)
|
||||
}
|
||||
|
||||
return parseGetentGroup(string(out))
|
||||
}
|
||||
|
||||
// parseGetentPasswd parses getent passwd output: "name:x:uid:gid:gecos:home:shell"
|
||||
func parseGetentPasswd(output string) (*user.User, string, error) {
|
||||
fields := strings.SplitN(strings.TrimSpace(output), ":", 8)
|
||||
@@ -90,20 +67,6 @@ func parseGetentPasswd(output string) (*user.User, string, error) {
|
||||
}, shell, nil
|
||||
}
|
||||
|
||||
// parseGetentGroup parses getent group output: "group:x:gid:user"
|
||||
func parseGetentGroup(output string) (*user.Group, error) {
|
||||
fields := strings.SplitN(strings.TrimSpace(output), ":", 8)
|
||||
if len(fields) < 4 {
|
||||
return nil, fmt.Errorf("unexpected getent output (need 4+ fields): %q", output)
|
||||
}
|
||||
|
||||
if fields[0] == "" || fields[2] == "" {
|
||||
return nil, fmt.Errorf("missing required fields in getent output: %q", output)
|
||||
}
|
||||
|
||||
return &user.Group{Gid: fields[2], Name: fields[0]}, nil
|
||||
}
|
||||
|
||||
// validateGetentInput checks that the input is safe to pass to getent or id.
|
||||
// Allows POSIX usernames, numeric UIDs, and common NSS extensions
|
||||
// (@ for Kerberos, $ for Samba, + for NIS compat).
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build !windows
|
||||
|
||||
package shell
|
||||
package server
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
@@ -195,7 +195,7 @@ func TestRunGetent_RootUser(t *testing.T) {
|
||||
t.Skip("getent not available on this system")
|
||||
}
|
||||
|
||||
u, shell, err := runGetentPasswd("root")
|
||||
u, shell, err := runGetent("root")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "root", u.Username)
|
||||
assert.Equal(t, "0", u.Uid)
|
||||
@@ -208,7 +208,7 @@ func TestRunGetent_ByUID(t *testing.T) {
|
||||
t.Skip("getent not available on this system")
|
||||
}
|
||||
|
||||
u, _, err := runGetentPasswd("0")
|
||||
u, _, err := runGetent("0")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "root", u.Username)
|
||||
assert.Equal(t, "0", u.Uid)
|
||||
@@ -219,15 +219,15 @@ func TestRunGetent_NonexistentUser(t *testing.T) {
|
||||
t.Skip("getent not available on this system")
|
||||
}
|
||||
|
||||
_, _, err := runGetentPasswd("nonexistent_user_xyzzy_12345")
|
||||
_, _, err := runGetent("nonexistent_user_xyzzy_12345")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRunGetent_InvalidInput(t *testing.T) {
|
||||
_, _, err := runGetentPasswd("")
|
||||
_, _, err := runGetent("")
|
||||
assert.Error(t, err)
|
||||
|
||||
_, _, err = runGetentPasswd("user\x00name")
|
||||
_, _, err = runGetent("user\x00name")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -236,7 +236,7 @@ func TestRunGetent_NotAvailable(t *testing.T) {
|
||||
t.Skip("getent is available, can't test missing case")
|
||||
}
|
||||
|
||||
_, _, err := runGetentPasswd("root")
|
||||
_, _, err := runGetent("root")
|
||||
assert.Error(t, err, "should fail when getent is not installed")
|
||||
}
|
||||
|
||||
@@ -283,7 +283,7 @@ func TestGetentResultsMatchStdlib(t *testing.T) {
|
||||
current, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
getentUser, _, err := runGetentPasswd(current.Username)
|
||||
getentUser, _, err := runGetent(current.Username)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, current.Username, getentUser.Username, "username should match")
|
||||
@@ -300,7 +300,7 @@ func TestGetentResultsMatchStdlib_ByUID(t *testing.T) {
|
||||
current, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
getentUser, _, err := runGetentPasswd(current.Uid)
|
||||
getentUser, _, err := runGetent(current.Uid)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, current.Username, getentUser.Username, "username should match when looked up by UID")
|
||||
@@ -356,7 +356,7 @@ func TestGetShellFromPasswd_CurrentUser(t *testing.T) {
|
||||
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
|
||||
|
||||
if _, err := exec.LookPath("getent"); err == nil {
|
||||
_, getentShell, getentErr := runGetentPasswd(current.Uid)
|
||||
_, getentShell, getentErr := runGetent(current.Uid)
|
||||
if getentErr == nil && getentShell != "" {
|
||||
assert.Equal(t, getentShell, shell, "shell from /etc/passwd should match getent")
|
||||
}
|
||||
@@ -400,7 +400,7 @@ func TestGetShellFromPasswd_MatchesGetentForKnownUsers(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
|
||||
_, getentShell, err := runGetentPasswd(uid)
|
||||
_, getentShell, err := runGetent(uid)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -1,26 +1,26 @@
|
||||
//go:build windows
|
||||
|
||||
package shell
|
||||
package server
|
||||
|
||||
import "os/user"
|
||||
|
||||
// lookupWithGetent on Windows just delegates to os/user.Lookup.
|
||||
// Windows does not use NSS/getent; its user lookup works without CGO.
|
||||
func LookupWithGetent(username string) (*user.User, error) {
|
||||
func lookupWithGetent(username string) (*user.User, error) {
|
||||
return user.Lookup(username)
|
||||
}
|
||||
|
||||
// currentUserWithGetent on Windows just delegates to os/user.Current.
|
||||
func CurrentUserWithGetent() (*user.User, error) {
|
||||
func currentUserWithGetent() (*user.User, error) {
|
||||
return user.Current()
|
||||
}
|
||||
|
||||
// getShellFromGetent is a no-op on Windows; shell resolution uses PowerShell detection.
|
||||
func GetShellFromGetent(_ string) string {
|
||||
func getShellFromGetent(_ string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// groupIdsWithFallback on Windows just delegates to u.GroupIds().
|
||||
func GroupIdsWithFallback(u *user.User) ([]string, error) {
|
||||
func groupIdsWithFallback(u *user.User) ([]string, error) {
|
||||
return u.GroupIds()
|
||||
}
|
||||
@@ -1,14 +1,17 @@
|
||||
package shell
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -21,7 +24,7 @@ const (
|
||||
|
||||
// getUserShell returns the appropriate shell for the given user ID
|
||||
// Handles all platform-specific logic and fallbacks consistently
|
||||
func GetUserShell(userID string) string {
|
||||
func getUserShell(userID string) string {
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
return getWindowsUserShell()
|
||||
@@ -53,7 +56,7 @@ func getUnixUserShell(userID string) string {
|
||||
return shell
|
||||
}
|
||||
|
||||
if shell := GetShellFromGetent(userID); shell != "" {
|
||||
if shell := getShellFromGetent(userID); shell != "" {
|
||||
return shell
|
||||
}
|
||||
|
||||
@@ -98,8 +101,8 @@ func getShellFromPasswd(userID string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// PrepareUserEnv prepares environment variables for user execution
|
||||
func PrepareUserEnv(user *user.User, shell string) []string {
|
||||
// prepareUserEnv prepares environment variables for user execution
|
||||
func prepareUserEnv(user *user.User, shell string) []string {
|
||||
pathValue := "/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games"
|
||||
if runtime.GOOS == "windows" {
|
||||
pathValue = `C:\Windows\System32;C:\Windows;C:\Windows\System32\Wbem;C:\Windows\System32\WindowsPowerShell\v1.0`
|
||||
@@ -116,7 +119,7 @@ func PrepareUserEnv(user *user.User, shell string) []string {
|
||||
|
||||
// acceptEnv checks if environment variable from SSH client should be accepted
|
||||
// This is a whitelist of variables that SSH clients can send to the server
|
||||
func AcceptEnv(envVar string) bool {
|
||||
func acceptEnv(envVar string) bool {
|
||||
varName := envVar
|
||||
if idx := strings.Index(envVar, "="); idx != -1 {
|
||||
varName = envVar[:idx]
|
||||
@@ -153,3 +156,29 @@ func AcceptEnv(envVar string) bool {
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// prepareSSHEnv prepares SSH protocol-specific environment variables
|
||||
// These variables provide information about the SSH connection itself
|
||||
func prepareSSHEnv(session ssh.Session) []string {
|
||||
remoteAddr := session.RemoteAddr()
|
||||
localAddr := session.LocalAddr()
|
||||
|
||||
remoteHost, remotePort, err := net.SplitHostPort(remoteAddr.String())
|
||||
if err != nil {
|
||||
remoteHost = remoteAddr.String()
|
||||
remotePort = "0"
|
||||
}
|
||||
|
||||
localHost, localPort, err := net.SplitHostPort(localAddr.String())
|
||||
if err != nil {
|
||||
localHost = localAddr.String()
|
||||
localPort = strconv.Itoa(InternalSSHPort)
|
||||
}
|
||||
|
||||
return []string{
|
||||
// SSH_CLIENT format: "client_ip client_port server_port"
|
||||
fmt.Sprintf("SSH_CLIENT=%s %s %s", remoteHost, remotePort, localPort),
|
||||
// SSH_CONNECTION format: "client_ip client_port server_ip server_port"
|
||||
fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", remoteHost, remotePort, localHost, localPort),
|
||||
}
|
||||
}
|
||||
@@ -3,15 +3,11 @@ package server
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/netbirdio/netbird/client/internal/shell"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -27,8 +23,8 @@ func isPlatformUnix() bool {
|
||||
|
||||
// Dependency injection variables for testing - allows mocking dynamic runtime checks
|
||||
var (
|
||||
getCurrentUser = shell.CurrentUserWithGetent
|
||||
lookupUser = shell.LookupWithGetent
|
||||
getCurrentUser = currentUserWithGetent
|
||||
lookupUser = lookupWithGetent
|
||||
getCurrentOS = func() string { return runtime.GOOS }
|
||||
getIsProcessPrivileged = isCurrentProcessPrivileged
|
||||
|
||||
@@ -413,29 +409,3 @@ func isWindowsElevated() bool {
|
||||
log.Debugf("Windows user switching not supported: not running as privileged user (current: %s)", currentUser.Uid)
|
||||
return false
|
||||
}
|
||||
|
||||
// prepareSSHEnv prepares SSH protocol-specific environment variables
|
||||
// These variables provide information about the SSH connection itself
|
||||
func prepareSSHEnv(session ssh.Session) []string {
|
||||
remoteAddr := session.RemoteAddr()
|
||||
localAddr := session.LocalAddr()
|
||||
|
||||
remoteHost, remotePort, err := net.SplitHostPort(remoteAddr.String())
|
||||
if err != nil {
|
||||
remoteHost = remoteAddr.String()
|
||||
remotePort = "0"
|
||||
}
|
||||
|
||||
localHost, localPort, err := net.SplitHostPort(localAddr.String())
|
||||
if err != nil {
|
||||
localHost = localAddr.String()
|
||||
localPort = strconv.Itoa(InternalSSHPort)
|
||||
}
|
||||
|
||||
return []string{
|
||||
// SSH_CLIENT format: "client_ip client_port server_port"
|
||||
fmt.Sprintf("SSH_CLIENT=%s %s %s", remoteHost, remotePort, localPort),
|
||||
// SSH_CONNECTION format: "client_ip client_port server_ip server_port"
|
||||
fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", remoteHost, remotePort, localHost, localPort),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/netbirdio/netbird/client/internal/shell"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -161,7 +160,7 @@ func (s *Server) parseUserCredentials(localUser *user.User) (uint32, uint32, []u
|
||||
// getSupplementaryGroups retrieves supplementary group IDs for a user.
|
||||
// Uses id/getent fallback for NSS users in CGO_ENABLED=0 builds.
|
||||
func (s *Server) getSupplementaryGroups(u *user.User) ([]uint32, error) {
|
||||
groupIDStrings, err := shell.GroupIdsWithFallback(u)
|
||||
groupIDStrings, err := groupIdsWithFallback(u)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group IDs for user %s: %w", u.Username, err)
|
||||
}
|
||||
@@ -197,7 +196,7 @@ func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, l
|
||||
GID: gid,
|
||||
Groups: groups,
|
||||
WorkingDir: localUser.HomeDir,
|
||||
Shell: shell.GetUserShell(localUser.Uid),
|
||||
Shell: getUserShell(localUser.Uid),
|
||||
Command: session.RawCommand(),
|
||||
PTY: hasPty,
|
||||
}
|
||||
@@ -229,7 +228,7 @@ func (s *Server) createPtyCommand(privilegeResult PrivilegeCheckResult, ptyReq s
|
||||
func (s *Server) createDirectPtyCommand(session ssh.Session, localUser *user.User, ptyReq ssh.Pty) *exec.Cmd {
|
||||
log.Debugf("creating direct Pty command for user %s (no user switching needed)", localUser.Username)
|
||||
|
||||
shell := shell.GetUserShell(localUser.Uid)
|
||||
shell := getUserShell(localUser.Uid)
|
||||
args := s.getShellCommandArgs(shell, session.RawCommand())
|
||||
|
||||
cmd := s.createShellCommand(session.Context(), shell, args)
|
||||
@@ -246,12 +245,12 @@ func (s *Server) preparePtyEnv(localUser *user.User, ptyReq ssh.Pty, session ssh
|
||||
termType = "xterm-256color"
|
||||
}
|
||||
|
||||
env := shell.PrepareUserEnv(localUser, shell.GetUserShell(localUser.Uid))
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
env = append(env, fmt.Sprintf("TERM=%s", termType))
|
||||
|
||||
for _, v := range session.Environ() {
|
||||
if shell.AcceptEnv(v) {
|
||||
if acceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,8 +13,6 @@ import (
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/shell"
|
||||
)
|
||||
|
||||
// validateUsername validates Windows usernames according to SAM Account Name rules
|
||||
@@ -106,7 +104,7 @@ func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, l
|
||||
func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session, localUser *user.User) (*exec.Cmd, func(), error) {
|
||||
username, domain := s.parseUsername(localUser.Username)
|
||||
|
||||
sh := shell.GetUserShell(localUser.Uid)
|
||||
shell := getUserShell(localUser.Uid)
|
||||
|
||||
rawCmd := session.RawCommand()
|
||||
var command string
|
||||
@@ -118,7 +116,7 @@ func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session,
|
||||
Username: username,
|
||||
Domain: domain,
|
||||
WorkingDir: localUser.HomeDir,
|
||||
Shell: sh,
|
||||
Shell: shell,
|
||||
Command: command,
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user