diff --git a/.github/workflows/golang-test-freebsd.yml b/.github/workflows/golang-test-freebsd.yml index b03313bbd..0d19e8a19 100644 --- a/.github/workflows/golang-test-freebsd.yml +++ b/.github/workflows/golang-test-freebsd.yml @@ -39,7 +39,7 @@ jobs: # check all component except management, since we do not support management server on freebsd time go test -timeout 1m -failfast ./base62/... # NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use` - time go test -timeout 8m -failfast -p 1 ./client/... + time go test -timeout 8m -failfast -v -p 1 ./client/... time go test -timeout 1m -failfast ./dns/... time go test -timeout 1m -failfast ./encryption/... time go test -timeout 1m -failfast ./formatter/... diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index a9bc1b979..2fa847dce 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,7 +9,7 @@ on: pull_request: env: - SIGN_PIPE_VER: "v0.0.23" + SIGN_PIPE_VER: "v0.1.0" GORELEASER_VER: "v2.3.2" PRODUCT_NAME: "NetBird" COPYRIGHT: "NetBird GmbH" @@ -19,6 +19,100 @@ concurrency: cancel-in-progress: true jobs: + release_freebsd_port: + name: "FreeBSD Port / Build & Test" + runs-on: ubuntu-22.04 + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Generate FreeBSD port diff + run: bash release_files/freebsd-port-diff.sh + + - name: Generate FreeBSD port issue body + run: bash release_files/freebsd-port-issue-body.sh + + - name: Check if diff was generated + id: check_diff + run: | + if ls netbird-*.diff 1> /dev/null 2>&1; then + echo "diff_exists=true" >> $GITHUB_OUTPUT + else + echo "diff_exists=false" >> $GITHUB_OUTPUT + echo "No diff file generated (port may already be up to date)" + fi + + - name: Extract version + if: steps.check_diff.outputs.diff_exists == 'true' + id: version + run: | + VERSION=$(ls netbird-*.diff | sed 's/netbird-\(.*\)\.diff/\1/') + echo "version=$VERSION" >> $GITHUB_OUTPUT + echo "Generated files for version: $VERSION" + cat netbird-*.diff + + - name: Test FreeBSD port + if: steps.check_diff.outputs.diff_exists == 'true' + uses: vmactions/freebsd-vm@v1 + with: + usesh: true + copyback: false + release: "15.0" + prepare: | + # Install required packages + pkg install -y git curl portlint go + + # Install Go for building + GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz" + GO_URL="https://go.dev/dl/$GO_TARBALL" + curl -LO "$GO_URL" + tar -C /usr/local -xzf "$GO_TARBALL" + + # Clone ports tree (shallow, only what we need) + git clone --depth 1 --filter=blob:none https://git.FreeBSD.org/ports.git /usr/ports + cd /usr/ports + + run: | + set -e -x + export PATH=$PATH:/usr/local/go/bin + + # Find the diff file + echo "Finding diff file..." + DIFF_FILE=$(find $PWD -name "netbird-*.diff" -type f 2>/dev/null | head -1) + echo "Found: $DIFF_FILE" + + if [[ -z "$DIFF_FILE" ]]; then + echo "ERROR: Could not find diff file" + find ~ -name "*.diff" -type f 2>/dev/null || true + exit 1 + fi + + # Apply the generated diff from /usr/ports (diff has a/security/netbird/... paths) + cd /usr/ports + patch -p1 -V none < "$DIFF_FILE" + + # Show patched Makefile + version=$(cat security/netbird/Makefile | grep -E '^DISTVERSION=' | awk '{print $NF}') + + cd /usr/ports/security/netbird + export BATCH=yes + make package + pkg add ./work/pkg/netbird-*.pkg + + netbird version | grep "$version" + + echo "FreeBSD port test completed successfully!" + + - name: Upload FreeBSD port files + if: steps.check_diff.outputs.diff_exists == 'true' + uses: actions/upload-artifact@v4 + with: + name: freebsd-port-files + path: | + ./netbird-*-issue.txt + ./netbird-*.diff + retention-days: 30 + release: runs-on: ubuntu-latest-m env: diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml index f4513e0e1..e2f950731 100644 --- a/.github/workflows/test-infrastructure-files.yml +++ b/.github/workflows/test-infrastructure-files.yml @@ -243,6 +243,7 @@ jobs: working-directory: infrastructure_files/artifacts run: | sleep 30 + docker compose logs docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City_[0-9]*.mmdb docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames_[0-9]*.db diff --git a/.gitignore b/.gitignore index e6c0c0aca..89024d190 100644 --- a/.gitignore +++ b/.gitignore @@ -31,3 +31,4 @@ infrastructure_files/setup-*.env .DS_Store vendor/ /netbird +client/netbird-electron/ diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 952e946dc..7c6651f83 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -713,8 +713,10 @@ checksum: extra_files: - glob: ./infrastructure_files/getting-started-with-zitadel.sh - glob: ./release_files/install.sh + - glob: ./infrastructure_files/getting-started.sh release: extra_files: - glob: ./infrastructure_files/getting-started-with-zitadel.sh - glob: ./release_files/install.sh + - glob: ./infrastructure_files/getting-started.sh diff --git a/README.md b/README.md index 2c5ee2ab6..28b53d5b6 100644 --- a/README.md +++ b/README.md @@ -85,7 +85,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird **Infrastructure requirements:** - A Linux VM with at least **1CPU** and **2GB** of memory. -- The VM should be publicly accessible on TCP ports **80** and **443** and UDP ports: **3478**, **49152-65535**. +- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port: **3478**. - **Public domain** name pointing to the VM. **Software requirements:** @@ -98,7 +98,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird **Steps** - Download and run the installation script: ```bash -export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started-with-zitadel.sh | bash +export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started.sh | bash ``` - Once finished, you can manage the resources via `docker-compose` @@ -113,7 +113,7 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird [Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.

- +

See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details. diff --git a/client/android/client.go b/client/android/client.go index 0d5474c4b..ccf32a90c 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -59,7 +59,6 @@ func init() { // Client struct manage the life circle of background service type Client struct { - cfgFile string tunAdapter device.TunAdapter iFaceDiscover IFaceDiscover recorder *peer.Status @@ -68,18 +67,16 @@ type Client struct { deviceName string uiVersion string networkChangeListener listener.NetworkChangeListener - stateFile string connectClient *internal.ConnectClient } // NewClient instantiate a new Client -func NewClient(platformFiles PlatformFiles, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client { +func NewClient(androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client { execWorkaround(androidSDKVersion) net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket) return &Client{ - cfgFile: platformFiles.ConfigurationFilePath(), deviceName: deviceName, uiVersion: uiVersion, tunAdapter: tunAdapter, @@ -87,15 +84,20 @@ func NewClient(platformFiles PlatformFiles, androidSDKVersion int, deviceName st recorder: peer.NewRecorder(""), ctxCancelLock: &sync.Mutex{}, networkChangeListener: networkChangeListener, - stateFile: platformFiles.StateFilePath(), } } // Run start the internal client. It is a blocker function -func (c *Client) Run(urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { +func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { exportEnvList(envList) + + cfgFile := platformFiles.ConfigurationFilePath() + stateFile := platformFiles.StateFilePath() + + log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile) + cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ - ConfigPath: c.cfgFile, + ConfigPath: cfgFile, }) if err != nil { return err @@ -122,16 +124,22 @@ func (c *Client) Run(urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsRea // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) - return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile) + c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false) + return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile) } // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // In this case make no sense handle registration steps. -func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { +func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { exportEnvList(envList) + + cfgFile := platformFiles.ConfigurationFilePath() + stateFile := platformFiles.StateFilePath() + + log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile) + cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ - ConfigPath: c.cfgFile, + ConfigPath: cfgFile, }) if err != nil { return err @@ -149,8 +157,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) - return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile) + c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false) + return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile) } // Stop the internal client and free the resources diff --git a/client/android/profile_manager.go b/client/android/profile_manager.go new file mode 100644 index 000000000..60e4d5c32 --- /dev/null +++ b/client/android/profile_manager.go @@ -0,0 +1,257 @@ +//go:build android + +package android + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/profilemanager" +) + +const ( + // Android-specific config filename (different from desktop default.json) + defaultConfigFilename = "netbird.cfg" + // Subdirectory for non-default profiles (must match Java Preferences.java) + profilesSubdir = "profiles" + // Android uses a single user context per app (non-empty username required by ServiceManager) + androidUsername = "android" +) + +// Profile represents a profile for gomobile +type Profile struct { + Name string + IsActive bool +} + +// ProfileArray wraps profiles for gomobile compatibility +type ProfileArray struct { + items []*Profile +} + +// Length returns the number of profiles +func (p *ProfileArray) Length() int { + return len(p.items) +} + +// Get returns the profile at index i +func (p *ProfileArray) Get(i int) *Profile { + if i < 0 || i >= len(p.items) { + return nil + } + return p.items[i] +} + +/* + +/data/data/io.netbird.client/files/ ← configDir parameter +├── netbird.cfg ← Default profile config +├── state.json ← Default profile state +├── active_profile.json ← Active profile tracker (JSON with Name + Username) +└── profiles/ ← Subdirectory for non-default profiles + ├── work.json ← Work profile config + ├── work.state.json ← Work profile state + ├── personal.json ← Personal profile config + └── personal.state.json ← Personal profile state +*/ + +// ProfileManager manages profiles for Android +// It wraps the internal profilemanager to provide Android-specific behavior +type ProfileManager struct { + configDir string + serviceMgr *profilemanager.ServiceManager +} + +// NewProfileManager creates a new profile manager for Android +func NewProfileManager(configDir string) *ProfileManager { + // Set the default config path for Android (stored in root configDir, not profiles/) + defaultConfigPath := filepath.Join(configDir, defaultConfigFilename) + + // Set global paths for Android + profilemanager.DefaultConfigPathDir = configDir + profilemanager.DefaultConfigPath = defaultConfigPath + profilemanager.ActiveProfileStatePath = filepath.Join(configDir, "active_profile.json") + + // Create ServiceManager with profiles/ subdirectory + // This avoids modifying the global ConfigDirOverride for profile listing + profilesDir := filepath.Join(configDir, profilesSubdir) + serviceMgr := profilemanager.NewServiceManagerWithProfilesDir(defaultConfigPath, profilesDir) + + return &ProfileManager{ + configDir: configDir, + serviceMgr: serviceMgr, + } +} + +// ListProfiles returns all available profiles +func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) { + // Use ServiceManager (looks in profiles/ directory, checks active_profile.json for IsActive) + internalProfiles, err := pm.serviceMgr.ListProfiles(androidUsername) + if err != nil { + return nil, fmt.Errorf("failed to list profiles: %w", err) + } + + // Convert internal profiles to Android Profile type + var profiles []*Profile + for _, p := range internalProfiles { + profiles = append(profiles, &Profile{ + Name: p.Name, + IsActive: p.IsActive, + }) + } + + return &ProfileArray{items: profiles}, nil +} + +// GetActiveProfile returns the currently active profile name +func (pm *ProfileManager) GetActiveProfile() (string, error) { + // Use ServiceManager to stay consistent with ListProfiles + // ServiceManager uses active_profile.json + activeState, err := pm.serviceMgr.GetActiveProfileState() + if err != nil { + return "", fmt.Errorf("failed to get active profile: %w", err) + } + return activeState.Name, nil +} + +// SwitchProfile switches to a different profile +func (pm *ProfileManager) SwitchProfile(profileName string) error { + // Use ServiceManager to stay consistent with ListProfiles + // ServiceManager uses active_profile.json + err := pm.serviceMgr.SetActiveProfileState(&profilemanager.ActiveProfileState{ + Name: profileName, + Username: androidUsername, + }) + if err != nil { + return fmt.Errorf("failed to switch profile: %w", err) + } + + log.Infof("switched to profile: %s", profileName) + return nil +} + +// AddProfile creates a new profile +func (pm *ProfileManager) AddProfile(profileName string) error { + // Use ServiceManager (creates profile in profiles/ directory) + if err := pm.serviceMgr.AddProfile(profileName, androidUsername); err != nil { + return fmt.Errorf("failed to add profile: %w", err) + } + + log.Infof("created new profile: %s", profileName) + return nil +} + +// LogoutProfile logs out from a profile (clears authentication) +func (pm *ProfileManager) LogoutProfile(profileName string) error { + profileName = sanitizeProfileName(profileName) + + configPath, err := pm.getProfileConfigPath(profileName) + if err != nil { + return err + } + + // Check if profile exists + if _, err := os.Stat(configPath); os.IsNotExist(err) { + return fmt.Errorf("profile '%s' does not exist", profileName) + } + + // Read current config using internal profilemanager + config, err := profilemanager.ReadConfig(configPath) + if err != nil { + return fmt.Errorf("failed to read profile config: %w", err) + } + + // Clear authentication by removing private key and SSH key + config.PrivateKey = "" + config.SSHKey = "" + + // Save config using internal profilemanager + if err := profilemanager.WriteOutConfig(configPath, config); err != nil { + return fmt.Errorf("failed to save config: %w", err) + } + + log.Infof("logged out from profile: %s", profileName) + return nil +} + +// RemoveProfile deletes a profile +func (pm *ProfileManager) RemoveProfile(profileName string) error { + // Use ServiceManager (removes profile from profiles/ directory) + if err := pm.serviceMgr.RemoveProfile(profileName, androidUsername); err != nil { + return fmt.Errorf("failed to remove profile: %w", err) + } + + log.Infof("removed profile: %s", profileName) + return nil +} + +// getProfileConfigPath returns the config file path for a profile +// This is needed for Android-specific path handling (netbird.cfg for default profile) +func (pm *ProfileManager) getProfileConfigPath(profileName string) (string, error) { + if profileName == "" || profileName == profilemanager.DefaultProfileName { + // Android uses netbird.cfg for default profile instead of default.json + // Default profile is stored in root configDir, not in profiles/ + return filepath.Join(pm.configDir, defaultConfigFilename), nil + } + + // Non-default profiles are stored in profiles subdirectory + // This matches the Java Preferences.java expectation + profileName = sanitizeProfileName(profileName) + profilesDir := filepath.Join(pm.configDir, profilesSubdir) + return filepath.Join(profilesDir, profileName+".json"), nil +} + +// GetConfigPath returns the config file path for a given profile +// Java should call this instead of constructing paths with Preferences.configFile() +func (pm *ProfileManager) GetConfigPath(profileName string) (string, error) { + return pm.getProfileConfigPath(profileName) +} + +// GetStateFilePath returns the state file path for a given profile +// Java should call this instead of constructing paths with Preferences.stateFile() +func (pm *ProfileManager) GetStateFilePath(profileName string) (string, error) { + if profileName == "" || profileName == profilemanager.DefaultProfileName { + return filepath.Join(pm.configDir, "state.json"), nil + } + + profileName = sanitizeProfileName(profileName) + profilesDir := filepath.Join(pm.configDir, profilesSubdir) + return filepath.Join(profilesDir, profileName+".state.json"), nil +} + +// GetActiveConfigPath returns the config file path for the currently active profile +// Java should call this instead of Preferences.getActiveProfileName() + Preferences.configFile() +func (pm *ProfileManager) GetActiveConfigPath() (string, error) { + activeProfile, err := pm.GetActiveProfile() + if err != nil { + return "", fmt.Errorf("failed to get active profile: %w", err) + } + return pm.GetConfigPath(activeProfile) +} + +// GetActiveStateFilePath returns the state file path for the currently active profile +// Java should call this instead of Preferences.getActiveProfileName() + Preferences.stateFile() +func (pm *ProfileManager) GetActiveStateFilePath() (string, error) { + activeProfile, err := pm.GetActiveProfile() + if err != nil { + return "", fmt.Errorf("failed to get active profile: %w", err) + } + return pm.GetStateFilePath(activeProfile) +} + +// sanitizeProfileName removes invalid characters from profile name +func sanitizeProfileName(name string) string { + // Keep only alphanumeric, underscore, and hyphen + var result strings.Builder + for _, r := range name { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || + (r >= '0' && r <= '9') || r == '_' || r == '-' { + result.WriteRune(r) + } + } + return result.String() +} diff --git a/client/cmd/root.go b/client/cmd/root.go index 9f2eb109c..30120c196 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -85,6 +85,9 @@ var ( // Execute executes the root command. func Execute() error { + if isUpdateBinary() { + return updateCmd.Execute() + } return rootCmd.Execute() } diff --git a/client/cmd/signer/artifactkey.go b/client/cmd/signer/artifactkey.go new file mode 100644 index 000000000..5e656650b --- /dev/null +++ b/client/cmd/signer/artifactkey.go @@ -0,0 +1,176 @@ +package main + +import ( + "fmt" + "os" + "time" + + "github.com/spf13/cobra" + + "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" +) + +var ( + bundlePubKeysRootPrivKeyFile string + bundlePubKeysPubKeyFiles []string + bundlePubKeysFile string + + createArtifactKeyRootPrivKeyFile string + createArtifactKeyPrivKeyFile string + createArtifactKeyPubKeyFile string + createArtifactKeyExpiration time.Duration +) + +var createArtifactKeyCmd = &cobra.Command{ + Use: "create-artifact-key", + Short: "Create a new artifact signing key", + Long: `Generate a new artifact signing key pair signed by the root private key. +The artifact key will be used to sign software artifacts/updates.`, + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + if createArtifactKeyExpiration <= 0 { + return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)") + } + + if err := handleCreateArtifactKey(cmd, createArtifactKeyRootPrivKeyFile, createArtifactKeyPrivKeyFile, createArtifactKeyPubKeyFile, createArtifactKeyExpiration); err != nil { + return fmt.Errorf("failed to create artifact key: %w", err) + } + return nil + }, +} + +var bundlePubKeysCmd = &cobra.Command{ + Use: "bundle-pub-keys", + Short: "Bundle multiple artifact public keys into a signed package", + Long: `Bundle one or more artifact public keys into a signed package using the root private key. +This command is typically used to distribute or authorize a set of valid artifact signing keys.`, + RunE: func(cmd *cobra.Command, args []string) error { + if len(bundlePubKeysPubKeyFiles) == 0 { + return fmt.Errorf("at least one --artifact-pub-key-file must be provided") + } + + if err := handleBundlePubKeys(cmd, bundlePubKeysRootPrivKeyFile, bundlePubKeysPubKeyFiles, bundlePubKeysFile); err != nil { + return fmt.Errorf("failed to bundle public keys: %w", err) + } + return nil + }, +} + +func init() { + rootCmd.AddCommand(createArtifactKeyCmd) + + createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the artifact key") + createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPrivKeyFile, "artifact-priv-key-file", "", "Path where the artifact private key will be saved") + createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPubKeyFile, "artifact-pub-key-file", "", "Path where the artifact public key will be saved") + createArtifactKeyCmd.Flags().DurationVar(&createArtifactKeyExpiration, "expiration", 0, "Expiration duration for the artifact key (e.g., 720h, 365d, 8760h)") + + if err := createArtifactKeyCmd.MarkFlagRequired("root-private-key-file"); err != nil { + panic(fmt.Errorf("mark root-private-key-file as required: %w", err)) + } + if err := createArtifactKeyCmd.MarkFlagRequired("artifact-priv-key-file"); err != nil { + panic(fmt.Errorf("mark artifact-priv-key-file as required: %w", err)) + } + if err := createArtifactKeyCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil { + panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err)) + } + if err := createArtifactKeyCmd.MarkFlagRequired("expiration"); err != nil { + panic(fmt.Errorf("mark expiration as required: %w", err)) + } + + rootCmd.AddCommand(bundlePubKeysCmd) + + bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the bundle") + bundlePubKeysCmd.Flags().StringArrayVar(&bundlePubKeysPubKeyFiles, "artifact-pub-key-file", nil, "Path(s) to the artifact public key files to include in the bundle (can be repeated)") + bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysFile, "bundle-pub-key-file", "", "Path where the public keys will be saved") + + if err := bundlePubKeysCmd.MarkFlagRequired("root-private-key-file"); err != nil { + panic(fmt.Errorf("mark root-private-key-file as required: %w", err)) + } + if err := bundlePubKeysCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil { + panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err)) + } + if err := bundlePubKeysCmd.MarkFlagRequired("bundle-pub-key-file"); err != nil { + panic(fmt.Errorf("mark bundle-pub-key-file as required: %w", err)) + } +} + +func handleCreateArtifactKey(cmd *cobra.Command, rootPrivKeyFile, artifactPrivKeyFile, artifactPubKeyFile string, expiration time.Duration) error { + cmd.Println("Creating new artifact signing key...") + + privKeyPEM, err := os.ReadFile(rootPrivKeyFile) + if err != nil { + return fmt.Errorf("read root private key file: %w", err) + } + + privateRootKey, err := reposign.ParseRootKey(privKeyPEM) + if err != nil { + return fmt.Errorf("failed to parse private root key: %w", err) + } + + artifactKey, privPEM, pubPEM, signature, err := reposign.GenerateArtifactKey(privateRootKey, expiration) + if err != nil { + return fmt.Errorf("generate artifact key: %w", err) + } + + if err := os.WriteFile(artifactPrivKeyFile, privPEM, 0o600); err != nil { + return fmt.Errorf("write private key file (%s): %w", artifactPrivKeyFile, err) + } + + if err := os.WriteFile(artifactPubKeyFile, pubPEM, 0o600); err != nil { + return fmt.Errorf("write public key file (%s): %w", artifactPubKeyFile, err) + } + + signatureFile := artifactPubKeyFile + ".sig" + if err := os.WriteFile(signatureFile, signature, 0o600); err != nil { + return fmt.Errorf("write signature file (%s): %w", signatureFile, err) + } + + cmd.Printf("✅ Artifact key created successfully.\n") + cmd.Printf("%s\n", artifactKey.String()) + return nil +} + +func handleBundlePubKeys(cmd *cobra.Command, rootPrivKeyFile string, artifactPubKeyFiles []string, bundlePubKeysFile string) error { + cmd.Println("📦 Bundling public keys into signed package...") + + privKeyPEM, err := os.ReadFile(rootPrivKeyFile) + if err != nil { + return fmt.Errorf("read root private key file: %w", err) + } + + privateRootKey, err := reposign.ParseRootKey(privKeyPEM) + if err != nil { + return fmt.Errorf("failed to parse private root key: %w", err) + } + + publicKeys := make([]reposign.PublicKey, 0, len(artifactPubKeyFiles)) + for _, pubFile := range artifactPubKeyFiles { + pubPem, err := os.ReadFile(pubFile) + if err != nil { + return fmt.Errorf("read public key file: %w", err) + } + + pk, err := reposign.ParseArtifactPubKey(pubPem) + if err != nil { + return fmt.Errorf("failed to parse artifact key: %w", err) + } + publicKeys = append(publicKeys, pk) + } + + parsedKeys, signature, err := reposign.BundleArtifactKeys(privateRootKey, publicKeys) + if err != nil { + return fmt.Errorf("bundle artifact keys: %w", err) + } + + if err := os.WriteFile(bundlePubKeysFile, parsedKeys, 0o600); err != nil { + return fmt.Errorf("write public keys file (%s): %w", bundlePubKeysFile, err) + } + + signatureFile := bundlePubKeysFile + ".sig" + if err := os.WriteFile(signatureFile, signature, 0o600); err != nil { + return fmt.Errorf("write signature file (%s): %w", signatureFile, err) + } + + cmd.Printf("✅ Bundle created with %d public keys.\n", len(artifactPubKeyFiles)) + return nil +} diff --git a/client/cmd/signer/artifactsign.go b/client/cmd/signer/artifactsign.go new file mode 100644 index 000000000..881be9367 --- /dev/null +++ b/client/cmd/signer/artifactsign.go @@ -0,0 +1,276 @@ +package main + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" + + "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" +) + +const ( + envArtifactPrivateKey = "NB_ARTIFACT_PRIV_KEY" +) + +var ( + signArtifactPrivKeyFile string + signArtifactArtifactFile string + + verifyArtifactPubKeyFile string + verifyArtifactFile string + verifyArtifactSignatureFile string + + verifyArtifactKeyPubKeyFile string + verifyArtifactKeyRootPubKeyFile string + verifyArtifactKeySignatureFile string + verifyArtifactKeyRevocationFile string +) + +var signArtifactCmd = &cobra.Command{ + Use: "sign-artifact", + Short: "Sign an artifact using an artifact private key", + Long: `Sign a software artifact (e.g., update bundle or binary) using the artifact's private key. +This command produces a detached signature that can be verified using the corresponding artifact public key.`, + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + if err := handleSignArtifact(cmd, signArtifactPrivKeyFile, signArtifactArtifactFile); err != nil { + return fmt.Errorf("failed to sign artifact: %w", err) + } + return nil + }, +} + +var verifyArtifactCmd = &cobra.Command{ + Use: "verify-artifact", + Short: "Verify an artifact signature using an artifact public key", + Long: `Verify a software artifact signature using the artifact's public key.`, + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + if err := handleVerifyArtifact(cmd, verifyArtifactPubKeyFile, verifyArtifactFile, verifyArtifactSignatureFile); err != nil { + return fmt.Errorf("failed to verify artifact: %w", err) + } + return nil + }, +} + +var verifyArtifactKeyCmd = &cobra.Command{ + Use: "verify-artifact-key", + Short: "Verify an artifact public key was signed by a root key", + Long: `Verify that an artifact public key (or bundle) was properly signed by a root key. +This validates the chain of trust from the root key to the artifact key.`, + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + if err := handleVerifyArtifactKey(cmd, verifyArtifactKeyPubKeyFile, verifyArtifactKeyRootPubKeyFile, verifyArtifactKeySignatureFile, verifyArtifactKeyRevocationFile); err != nil { + return fmt.Errorf("failed to verify artifact key: %w", err) + } + return nil + }, +} + +func init() { + rootCmd.AddCommand(signArtifactCmd) + rootCmd.AddCommand(verifyArtifactCmd) + rootCmd.AddCommand(verifyArtifactKeyCmd) + + signArtifactCmd.Flags().StringVar(&signArtifactPrivKeyFile, "artifact-key-file", "", fmt.Sprintf("Path to the artifact private key file used for signing (or set %s env var)", envArtifactPrivateKey)) + signArtifactCmd.Flags().StringVar(&signArtifactArtifactFile, "artifact-file", "", "Path to the artifact to be signed") + + // artifact-file is required, but artifact-key-file can come from env var + if err := signArtifactCmd.MarkFlagRequired("artifact-file"); err != nil { + panic(fmt.Errorf("mark artifact-file as required: %w", err)) + } + + verifyArtifactCmd.Flags().StringVar(&verifyArtifactPubKeyFile, "artifact-public-key-file", "", "Path to the artifact public key file") + verifyArtifactCmd.Flags().StringVar(&verifyArtifactFile, "artifact-file", "", "Path to the artifact to be verified") + verifyArtifactCmd.Flags().StringVar(&verifyArtifactSignatureFile, "signature-file", "", "Path to the signature file") + + if err := verifyArtifactCmd.MarkFlagRequired("artifact-public-key-file"); err != nil { + panic(fmt.Errorf("mark artifact-public-key-file as required: %w", err)) + } + if err := verifyArtifactCmd.MarkFlagRequired("artifact-file"); err != nil { + panic(fmt.Errorf("mark artifact-file as required: %w", err)) + } + if err := verifyArtifactCmd.MarkFlagRequired("signature-file"); err != nil { + panic(fmt.Errorf("mark signature-file as required: %w", err)) + } + + verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyPubKeyFile, "artifact-key-file", "", "Path to the artifact public key file or bundle") + verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRootPubKeyFile, "root-key-file", "", "Path to the root public key file or bundle") + verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeySignatureFile, "signature-file", "", "Path to the signature file") + verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRevocationFile, "revocation-file", "", "Path to the revocation list file (optional)") + + if err := verifyArtifactKeyCmd.MarkFlagRequired("artifact-key-file"); err != nil { + panic(fmt.Errorf("mark artifact-key-file as required: %w", err)) + } + if err := verifyArtifactKeyCmd.MarkFlagRequired("root-key-file"); err != nil { + panic(fmt.Errorf("mark root-key-file as required: %w", err)) + } + if err := verifyArtifactKeyCmd.MarkFlagRequired("signature-file"); err != nil { + panic(fmt.Errorf("mark signature-file as required: %w", err)) + } +} + +func handleSignArtifact(cmd *cobra.Command, privKeyFile, artifactFile string) error { + cmd.Println("🖋️ Signing artifact...") + + // Load private key from env var or file + var privKeyPEM []byte + var err error + + if envKey := os.Getenv(envArtifactPrivateKey); envKey != "" { + // Use key from environment variable + privKeyPEM = []byte(envKey) + } else if privKeyFile != "" { + // Fall back to file + privKeyPEM, err = os.ReadFile(privKeyFile) + if err != nil { + return fmt.Errorf("read private key file: %w", err) + } + } else { + return fmt.Errorf("artifact private key must be provided via %s environment variable or --artifact-key-file flag", envArtifactPrivateKey) + } + + privateKey, err := reposign.ParseArtifactKey(privKeyPEM) + if err != nil { + return fmt.Errorf("failed to parse artifact private key: %w", err) + } + + artifactData, err := os.ReadFile(artifactFile) + if err != nil { + return fmt.Errorf("read artifact file: %w", err) + } + + signature, err := reposign.SignData(privateKey, artifactData) + if err != nil { + return fmt.Errorf("sign artifact: %w", err) + } + + sigFile := artifactFile + ".sig" + if err := os.WriteFile(artifactFile+".sig", signature, 0o600); err != nil { + return fmt.Errorf("write signature file (%s): %w", sigFile, err) + } + + cmd.Printf("✅ Artifact signed successfully.\n") + cmd.Printf("Signature file: %s\n", sigFile) + return nil +} + +func handleVerifyArtifact(cmd *cobra.Command, pubKeyFile, artifactFile, signatureFile string) error { + cmd.Println("🔍 Verifying artifact...") + + // Read artifact public key + pubKeyPEM, err := os.ReadFile(pubKeyFile) + if err != nil { + return fmt.Errorf("read public key file: %w", err) + } + + publicKey, err := reposign.ParseArtifactPubKey(pubKeyPEM) + if err != nil { + return fmt.Errorf("failed to parse artifact public key: %w", err) + } + + // Read artifact data + artifactData, err := os.ReadFile(artifactFile) + if err != nil { + return fmt.Errorf("read artifact file: %w", err) + } + + // Read signature + sigBytes, err := os.ReadFile(signatureFile) + if err != nil { + return fmt.Errorf("read signature file: %w", err) + } + + signature, err := reposign.ParseSignature(sigBytes) + if err != nil { + return fmt.Errorf("failed to parse signature: %w", err) + } + + // Validate artifact + if err := reposign.ValidateArtifact([]reposign.PublicKey{publicKey}, artifactData, *signature); err != nil { + return fmt.Errorf("artifact verification failed: %w", err) + } + + cmd.Println("✅ Artifact signature is valid") + cmd.Printf("Artifact: %s\n", artifactFile) + cmd.Printf("Signed by key: %s\n", signature.KeyID) + cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST")) + return nil +} + +func handleVerifyArtifactKey(cmd *cobra.Command, artifactKeyFile, rootKeyFile, signatureFile, revocationFile string) error { + cmd.Println("🔍 Verifying artifact key...") + + // Read artifact key data + artifactKeyData, err := os.ReadFile(artifactKeyFile) + if err != nil { + return fmt.Errorf("read artifact key file: %w", err) + } + + // Read root public key(s) + rootKeyData, err := os.ReadFile(rootKeyFile) + if err != nil { + return fmt.Errorf("read root key file: %w", err) + } + + rootPublicKeys, err := parseRootPublicKeys(rootKeyData) + if err != nil { + return fmt.Errorf("failed to parse root public key(s): %w", err) + } + + // Read signature + sigBytes, err := os.ReadFile(signatureFile) + if err != nil { + return fmt.Errorf("read signature file: %w", err) + } + + signature, err := reposign.ParseSignature(sigBytes) + if err != nil { + return fmt.Errorf("failed to parse signature: %w", err) + } + + // Read optional revocation list + var revocationList *reposign.RevocationList + if revocationFile != "" { + revData, err := os.ReadFile(revocationFile) + if err != nil { + return fmt.Errorf("read revocation file: %w", err) + } + + revocationList, err = reposign.ParseRevocationList(revData) + if err != nil { + return fmt.Errorf("failed to parse revocation list: %w", err) + } + } + + // Validate artifact key(s) + validKeys, err := reposign.ValidateArtifactKeys(rootPublicKeys, artifactKeyData, *signature, revocationList) + if err != nil { + return fmt.Errorf("artifact key verification failed: %w", err) + } + + cmd.Println("✅ Artifact key(s) verified successfully") + cmd.Printf("Signed by root key: %s\n", signature.KeyID) + cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST")) + cmd.Printf("\nValid artifact keys (%d):\n", len(validKeys)) + for i, key := range validKeys { + cmd.Printf(" [%d] Key ID: %s\n", i+1, key.Metadata.ID) + cmd.Printf(" Created: %s\n", key.Metadata.CreatedAt.Format("2006-01-02 15:04:05 MST")) + if !key.Metadata.ExpiresAt.IsZero() { + cmd.Printf(" Expires: %s\n", key.Metadata.ExpiresAt.Format("2006-01-02 15:04:05 MST")) + } else { + cmd.Printf(" Expires: Never\n") + } + } + return nil +} + +// parseRootPublicKeys parses a root public key from PEM data +func parseRootPublicKeys(data []byte) ([]reposign.PublicKey, error) { + key, err := reposign.ParseRootPublicKey(data) + if err != nil { + return nil, err + } + return []reposign.PublicKey{key}, nil +} diff --git a/client/cmd/signer/main.go b/client/cmd/signer/main.go new file mode 100644 index 000000000..407093d07 --- /dev/null +++ b/client/cmd/signer/main.go @@ -0,0 +1,21 @@ +package main + +import ( + "os" + + "github.com/spf13/cobra" +) + +var rootCmd = &cobra.Command{ + Use: "signer", + Short: "A CLI tool for managing cryptographic keys and artifacts", + Long: `signer is a command-line tool that helps you manage +root keys, artifact keys, and revocation lists securely.`, +} + +func main() { + if err := rootCmd.Execute(); err != nil { + rootCmd.Println(err) + os.Exit(1) + } +} diff --git a/client/cmd/signer/revocation.go b/client/cmd/signer/revocation.go new file mode 100644 index 000000000..1d84b65c3 --- /dev/null +++ b/client/cmd/signer/revocation.go @@ -0,0 +1,220 @@ +package main + +import ( + "fmt" + "os" + "time" + + "github.com/spf13/cobra" + + "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" +) + +const ( + defaultRevocationListExpiration = 365 * 24 * time.Hour // 1 year +) + +var ( + keyID string + revocationListFile string + privateRootKeyFile string + publicRootKeyFile string + signatureFile string + expirationDuration time.Duration +) + +var createRevocationListCmd = &cobra.Command{ + Use: "create-revocation-list", + Short: "Create a new revocation list signed by the private root key", + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + return handleCreateRevocationList(cmd, revocationListFile, privateRootKeyFile) + }, +} + +var extendRevocationListCmd = &cobra.Command{ + Use: "extend-revocation-list", + Short: "Extend an existing revocation list with a given key ID", + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + return handleExtendRevocationList(cmd, keyID, revocationListFile, privateRootKeyFile) + }, +} + +var verifyRevocationListCmd = &cobra.Command{ + Use: "verify-revocation-list", + Short: "Verify a revocation list signature using the public root key", + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + return handleVerifyRevocationList(cmd, revocationListFile, signatureFile, publicRootKeyFile) + }, +} + +func init() { + rootCmd.AddCommand(createRevocationListCmd) + rootCmd.AddCommand(extendRevocationListCmd) + rootCmd.AddCommand(verifyRevocationListCmd) + + createRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file") + createRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file") + createRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)") + if err := createRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil { + panic(err) + } + if err := createRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil { + panic(err) + } + + extendRevocationListCmd.Flags().StringVar(&keyID, "key-id", "", "ID of the key to extend the revocation list for") + extendRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file") + extendRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file") + extendRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)") + if err := extendRevocationListCmd.MarkFlagRequired("key-id"); err != nil { + panic(err) + } + if err := extendRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil { + panic(err) + } + if err := extendRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil { + panic(err) + } + + verifyRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the revocation list file") + verifyRevocationListCmd.Flags().StringVar(&signatureFile, "signature-file", "", "Path to the signature file") + verifyRevocationListCmd.Flags().StringVar(&publicRootKeyFile, "public-root-key", "", "Path to the public root key PEM file") + if err := verifyRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil { + panic(err) + } + if err := verifyRevocationListCmd.MarkFlagRequired("signature-file"); err != nil { + panic(err) + } + if err := verifyRevocationListCmd.MarkFlagRequired("public-root-key"); err != nil { + panic(err) + } +} + +func handleCreateRevocationList(cmd *cobra.Command, revocationListFile string, privateRootKeyFile string) error { + privKeyPEM, err := os.ReadFile(privateRootKeyFile) + if err != nil { + return fmt.Errorf("failed to read private root key file: %w", err) + } + + privateRootKey, err := reposign.ParseRootKey(privKeyPEM) + if err != nil { + return fmt.Errorf("failed to parse private root key: %w", err) + } + + rlBytes, sigBytes, err := reposign.CreateRevocationList(*privateRootKey, expirationDuration) + if err != nil { + return fmt.Errorf("failed to create revocation list: %w", err) + } + + if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", rlBytes, sigBytes); err != nil { + return fmt.Errorf("failed to write output files: %w", err) + } + + cmd.Println("✅ Revocation list created successfully") + return nil +} + +func handleExtendRevocationList(cmd *cobra.Command, keyID, revocationListFile, privateRootKeyFile string) error { + privKeyPEM, err := os.ReadFile(privateRootKeyFile) + if err != nil { + return fmt.Errorf("failed to read private root key file: %w", err) + } + + privateRootKey, err := reposign.ParseRootKey(privKeyPEM) + if err != nil { + return fmt.Errorf("failed to parse private root key: %w", err) + } + + rlBytes, err := os.ReadFile(revocationListFile) + if err != nil { + return fmt.Errorf("failed to read revocation list file: %w", err) + } + + rl, err := reposign.ParseRevocationList(rlBytes) + if err != nil { + return fmt.Errorf("failed to parse revocation list: %w", err) + } + + kid, err := reposign.ParseKeyID(keyID) + if err != nil { + return fmt.Errorf("invalid key ID: %w", err) + } + + newRLBytes, sigBytes, err := reposign.ExtendRevocationList(*privateRootKey, *rl, kid, expirationDuration) + if err != nil { + return fmt.Errorf("failed to extend revocation list: %w", err) + } + + if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", newRLBytes, sigBytes); err != nil { + return fmt.Errorf("failed to write output files: %w", err) + } + + cmd.Println("✅ Revocation list extended successfully") + return nil +} + +func handleVerifyRevocationList(cmd *cobra.Command, revocationListFile, signatureFile, publicRootKeyFile string) error { + // Read revocation list file + rlBytes, err := os.ReadFile(revocationListFile) + if err != nil { + return fmt.Errorf("failed to read revocation list file: %w", err) + } + + // Read signature file + sigBytes, err := os.ReadFile(signatureFile) + if err != nil { + return fmt.Errorf("failed to read signature file: %w", err) + } + + // Read public root key file + pubKeyPEM, err := os.ReadFile(publicRootKeyFile) + if err != nil { + return fmt.Errorf("failed to read public root key file: %w", err) + } + + // Parse public root key + publicKey, err := reposign.ParseRootPublicKey(pubKeyPEM) + if err != nil { + return fmt.Errorf("failed to parse public root key: %w", err) + } + + // Parse signature + signature, err := reposign.ParseSignature(sigBytes) + if err != nil { + return fmt.Errorf("failed to parse signature: %w", err) + } + + // Validate revocation list + rl, err := reposign.ValidateRevocationList([]reposign.PublicKey{publicKey}, rlBytes, *signature) + if err != nil { + return fmt.Errorf("failed to validate revocation list: %w", err) + } + + // Display results + cmd.Println("✅ Revocation list signature is valid") + cmd.Printf("Last Updated: %s\n", rl.LastUpdated.Format(time.RFC3339)) + cmd.Printf("Expires At: %s\n", rl.ExpiresAt.Format(time.RFC3339)) + cmd.Printf("Number of revoked keys: %d\n", len(rl.Revoked)) + + if len(rl.Revoked) > 0 { + cmd.Println("\nRevoked Keys:") + for keyID, revokedTime := range rl.Revoked { + cmd.Printf(" - %s (revoked at: %s)\n", keyID, revokedTime.Format(time.RFC3339)) + } + } + + return nil +} + +func writeOutputFiles(rlPath, sigPath string, rlBytes, sigBytes []byte) error { + if err := os.WriteFile(rlPath, rlBytes, 0o600); err != nil { + return fmt.Errorf("failed to write revocation list file: %w", err) + } + if err := os.WriteFile(sigPath, sigBytes, 0o600); err != nil { + return fmt.Errorf("failed to write signature file: %w", err) + } + return nil +} diff --git a/client/cmd/signer/rootkey.go b/client/cmd/signer/rootkey.go new file mode 100644 index 000000000..78ac36b41 --- /dev/null +++ b/client/cmd/signer/rootkey.go @@ -0,0 +1,74 @@ +package main + +import ( + "fmt" + "os" + "time" + + "github.com/spf13/cobra" + + "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" +) + +var ( + privKeyFile string + pubKeyFile string + rootExpiration time.Duration +) + +var createRootKeyCmd = &cobra.Command{ + Use: "create-root-key", + Short: "Create a new root key pair", + Long: `Create a new root key pair and specify an expiration time for it.`, + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + // Validate expiration + if rootExpiration <= 0 { + return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)") + } + + // Run main logic + if err := handleGenerateRootKey(cmd, privKeyFile, pubKeyFile, rootExpiration); err != nil { + return fmt.Errorf("failed to generate root key: %w", err) + } + return nil + }, +} + +func init() { + rootCmd.AddCommand(createRootKeyCmd) + createRootKeyCmd.Flags().StringVar(&privKeyFile, "priv-key-file", "", "Path to output private key file") + createRootKeyCmd.Flags().StringVar(&pubKeyFile, "pub-key-file", "", "Path to output public key file") + createRootKeyCmd.Flags().DurationVar(&rootExpiration, "expiration", 0, "Expiration time for the root key (e.g., 720h,)") + + if err := createRootKeyCmd.MarkFlagRequired("priv-key-file"); err != nil { + panic(err) + } + if err := createRootKeyCmd.MarkFlagRequired("pub-key-file"); err != nil { + panic(err) + } + if err := createRootKeyCmd.MarkFlagRequired("expiration"); err != nil { + panic(err) + } +} + +func handleGenerateRootKey(cmd *cobra.Command, privKeyFile, pubKeyFile string, expiration time.Duration) error { + rk, privPEM, pubPEM, err := reposign.GenerateRootKey(expiration) + if err != nil { + return fmt.Errorf("generate root key: %w", err) + } + + // Write private key + if err := os.WriteFile(privKeyFile, privPEM, 0o600); err != nil { + return fmt.Errorf("write private key file (%s): %w", privKeyFile, err) + } + + // Write public key + if err := os.WriteFile(pubKeyFile, pubPEM, 0o600); err != nil { + return fmt.Errorf("write public key file (%s): %w", pubKeyFile, err) + } + + cmd.Printf("%s\n\n", rk.String()) + cmd.Printf("✅ Root key pair generated successfully.\n") + return nil +} diff --git a/client/cmd/ssh.go b/client/cmd/ssh.go index 525bcdef1..0acf0b133 100644 --- a/client/cmd/ssh.go +++ b/client/cmd/ssh.go @@ -634,7 +634,11 @@ func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward return err } - cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr) + if err := validateDestinationPort(remoteAddr); err != nil { + return fmt.Errorf("invalid remote address: %w", err) + } + + log.Debugf("Local port forwarding: %s -> %s", localAddr, remoteAddr) go func() { if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) { @@ -652,7 +656,11 @@ func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forwar return err } - cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr) + if err := validateDestinationPort(localAddr); err != nil { + return fmt.Errorf("invalid local address: %w", err) + } + + log.Debugf("Remote port forwarding: %s -> %s", remoteAddr, localAddr) go func() { if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) { @@ -663,6 +671,35 @@ func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forwar return nil } +// validateDestinationPort checks that the destination address has a valid port. +// Port 0 is only valid for bind addresses (where the OS picks an available port), +// not for destination addresses where we need to connect. +func validateDestinationPort(addr string) error { + if strings.HasPrefix(addr, "/") || strings.HasPrefix(addr, "./") { + return nil + } + + _, portStr, err := net.SplitHostPort(addr) + if err != nil { + return fmt.Errorf("parse address %s: %w", addr, err) + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return fmt.Errorf("invalid port %s: %w", portStr, err) + } + + if port == 0 { + return fmt.Errorf("port 0 is not valid for destination address") + } + + if port < 0 || port > 65535 { + return fmt.Errorf("port %d out of range (1-65535)", port) + } + + return nil +} + // parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80". // Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket". func parsePortForwardSpec(spec string) (string, string, error) { diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index b9ff35945..888a9a3f7 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -127,7 +127,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp if err != nil { t.Fatal(err) } - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil) if err != nil { t.Fatal(err) } diff --git a/client/cmd/up.go b/client/cmd/up.go index 140ba2cb2..9efc2e60d 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -197,7 +197,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr r := peer.NewRecorder(config.ManagementURL.String()) r.GetFullStatus() - connectClient := internal.NewConnectClient(ctx, config, r) + connectClient := internal.NewConnectClient(ctx, config, r, false) SetupDebugHandler(ctx, config, r, connectClient, "") return connectClient.Run(nil) diff --git a/client/cmd/update.go b/client/cmd/update.go new file mode 100644 index 000000000..dc49b02c3 --- /dev/null +++ b/client/cmd/update.go @@ -0,0 +1,13 @@ +//go:build !windows && !darwin + +package cmd + +import ( + "github.com/spf13/cobra" +) + +var updateCmd *cobra.Command + +func isUpdateBinary() bool { + return false +} diff --git a/client/cmd/update_supported.go b/client/cmd/update_supported.go new file mode 100644 index 000000000..977875093 --- /dev/null +++ b/client/cmd/update_supported.go @@ -0,0 +1,75 @@ +//go:build windows || darwin + +package cmd + +import ( + "context" + "os" + "path/filepath" + "strings" + + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "github.com/netbirdio/netbird/client/internal/updatemanager/installer" + "github.com/netbirdio/netbird/util" +) + +var ( + updateCmd = &cobra.Command{ + Use: "update", + Short: "Update the NetBird client application", + RunE: updateFunc, + } + + tempDirFlag string + installerFile string + serviceDirFlag string + dryRunFlag bool +) + +func init() { + updateCmd.Flags().StringVar(&tempDirFlag, "temp-dir", "", "temporary dir") + updateCmd.Flags().StringVar(&installerFile, "installer-file", "", "installer file") + updateCmd.Flags().StringVar(&serviceDirFlag, "service-dir", "", "service directory") + updateCmd.Flags().BoolVar(&dryRunFlag, "dry-run", false, "dry run the update process without making any changes") +} + +// isUpdateBinary checks if the current executable is named "update" or "update.exe" +func isUpdateBinary() bool { + // Remove extension for cross-platform compatibility + execPath, err := os.Executable() + if err != nil { + return false + } + baseName := filepath.Base(execPath) + name := strings.TrimSuffix(baseName, filepath.Ext(baseName)) + + return name == installer.UpdaterBinaryNameWithoutExtension() +} + +func updateFunc(cmd *cobra.Command, args []string) error { + if err := setupLogToFile(tempDirFlag); err != nil { + return err + } + + log.Infof("updater started: %s", serviceDirFlag) + updater := installer.NewWithDir(tempDirFlag) + if err := updater.Setup(context.Background(), dryRunFlag, installerFile, serviceDirFlag); err != nil { + log.Errorf("failed to update application: %v", err) + return err + } + return nil +} + +func setupLogToFile(dir string) error { + logFile := filepath.Join(dir, installer.LogFile) + + if _, err := os.Stat(logFile); err == nil { + if err := os.Remove(logFile); err != nil { + log.Errorf("failed to remove existing log file: %v\n", err) + } + } + + return util.InitLog(logLevel, util.LogConsole, logFile) +} diff --git a/client/embed/embed.go b/client/embed/embed.go index 3090ca6a2..353c5438f 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -173,7 +173,7 @@ func (c *Client) Start(startCtx context.Context) error { } recorder := peer.NewRecorder(c.config.ManagementURL.String()) - client := internal.NewConnectClient(ctx, c.config, recorder) + client := internal.NewConnectClient(ctx, c.config, recorder, false) // either startup error (permanent backoff err) or nil err (successful engine up) // TODO: make after-startup backoff err available diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index adec802c8..6b29c5606 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -386,6 +386,97 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { verifyIptablesOutput(t, stdout, stderr) } +func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + if _, err := exec.LookPath("iptables-save"); err != nil { + t.Skipf("iptables-save not available on this system: %v", err) + } + + // First ensure iptables-nft tables exist by running iptables-save + stdout, stderr := runIptablesSave(t) + verifyIptablesOutput(t, stdout, stderr) + + manager, err := Create(ifaceMock, iface.DefaultMTU) + require.NoError(t, err, "failed to create manager") + require.NoError(t, manager.Init(nil)) + + t.Cleanup(func() { + err := manager.Close(nil) + require.NoError(t, err, "failed to reset manager state") + + // Verify iptables output after reset + stdout, stderr := runIptablesSave(t) + verifyIptablesOutput(t, stdout, stderr) + }) + + const octet2Count = 25 + const octet3Count = 255 + prefixes := make([]netip.Prefix, 0, (octet2Count-1)*(octet3Count-1)) + for i := 1; i < octet2Count; i++ { + for j := 1; j < octet3Count; j++ { + addr := netip.AddrFrom4([4]byte{192, byte(j), byte(i), 0}) + prefixes = append(prefixes, netip.PrefixFrom(addr, 24)) + } + } + _, err = manager.AddRouteFiltering( + nil, + prefixes, + fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")}, + fw.ProtocolTCP, + nil, + &fw.Port{Values: []uint16{443}}, + fw.ActionAccept, + ) + require.NoError(t, err, "failed to add route filtering rule") + + stdout, stderr = runIptablesSave(t) + verifyIptablesOutput(t, stdout, stderr) +} + +func TestNftablesManagerCompatibilityWithIptablesForEmptyPrefixes(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + if _, err := exec.LookPath("iptables-save"); err != nil { + t.Skipf("iptables-save not available on this system: %v", err) + } + + // First ensure iptables-nft tables exist by running iptables-save + stdout, stderr := runIptablesSave(t) + verifyIptablesOutput(t, stdout, stderr) + + manager, err := Create(ifaceMock, iface.DefaultMTU) + require.NoError(t, err, "failed to create manager") + require.NoError(t, manager.Init(nil)) + + t.Cleanup(func() { + err := manager.Close(nil) + require.NoError(t, err, "failed to reset manager state") + + // Verify iptables output after reset + stdout, stderr := runIptablesSave(t) + verifyIptablesOutput(t, stdout, stderr) + }) + + _, err = manager.AddRouteFiltering( + nil, + []netip.Prefix{}, + fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")}, + fw.ProtocolTCP, + nil, + &fw.Port{Values: []uint16{443}}, + fw.ActionAccept, + ) + require.NoError(t, err, "failed to add route filtering rule") + + stdout, stderr = runIptablesSave(t) + verifyIptablesOutput(t, stdout, stderr) +} + func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) { t.Helper() require.Equal(t, len(got), len(want), "expression count mismatch") diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 7f95992da..b6e0cf5b2 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -48,9 +48,11 @@ const ( // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation ipTCPHeaderMinSize = 40 -) -const refreshRulesMapError = "refresh rules map: %w" + // maxPrefixesSet 1638 prefixes start to fail, taking some margin + maxPrefixesSet = 1500 + refreshRulesMapError = "refresh rules map: %w" +) var ( errFilterTableNotFound = fmt.Errorf("'filter' table not found") @@ -513,16 +515,35 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err } elements := convertPrefixesToSet(prefixes) - if err := r.conn.AddSet(nfset, elements); err != nil { - return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err) - } + nElements := len(elements) + maxElements := maxPrefixesSet * 2 + initialElements := elements[:min(maxElements, nElements)] + + if err := r.conn.AddSet(nfset, initialElements); err != nil { + return nil, fmt.Errorf("error adding set %s: %w", setName, err) + } if err := r.conn.Flush(); err != nil { return nil, fmt.Errorf("flush error: %w", err) } + log.Debugf("Created new ipset: %s with %d initial prefixes (total prefixes %d)", setName, len(initialElements)/2, len(prefixes)) - log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2) + var subEnd int + for subStart := maxElements; subStart < nElements; subStart += maxElements { + subEnd = min(subStart+maxElements, nElements) + subElement := elements[subStart:subEnd] + nSubPrefixes := len(subElement) / 2 + log.Tracef("Adding new prefixes (%d) in ipset: %s", nSubPrefixes, setName) + if err := r.conn.SetAddElements(nfset, subElement); err != nil { + return nil, fmt.Errorf("error adding prefixes (%d) to set %s: %w", nSubPrefixes, setName, err) + } + if err := r.conn.Flush(); err != nil { + return nil, fmt.Errorf("flush error: %w", err) + } + log.Debugf("Added new prefixes (%d) in ipset: %s", nSubPrefixes, setName) + } + log.Infof("Created new ipset: %s with %d prefixes", setName, len(prefixes)) return nfset, nil } diff --git a/client/iface/device/device_ios.go b/client/iface/device/device_ios.go index f96edf992..d841ac2fe 100644 --- a/client/iface/device/device_ios.go +++ b/client/iface/device/device_ios.go @@ -4,6 +4,7 @@ package device import ( + "fmt" "os" log "github.com/sirupsen/logrus" @@ -45,10 +46,31 @@ func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu } } +// ErrInvalidTunnelFD is returned when the tunnel file descriptor is invalid (0). +// This typically means the Swift code couldn't find the utun control socket. +var ErrInvalidTunnelFD = fmt.Errorf("invalid tunnel file descriptor: fd is 0 (Swift failed to locate utun socket)") + func (t *TunDevice) Create() (WGConfigurer, error) { log.Infof("create tun interface") - dupTunFd, err := unix.Dup(t.tunFd) + var tunDevice tun.Device + var err error + + // Validate the tunnel file descriptor. + // On iOS/tvOS, the FD must be provided by the NEPacketTunnelProvider. + // A value of 0 means the Swift code couldn't find the utun control socket + // (the low-level APIs like ctl_info, sockaddr_ctl may not be exposed in + // tvOS SDK headers). This is a hard error - there's no viable fallback + // since tun.CreateTUN() cannot work within the iOS/tvOS sandbox. + if t.tunFd == 0 { + log.Errorf("Tunnel file descriptor is 0 - Swift code failed to locate the utun control socket. " + + "On tvOS, ensure the NEPacketTunnelProvider is properly configured and the tunnel is started.") + return nil, ErrInvalidTunnelFD + } + + // Normal iOS/tvOS path: use the provided file descriptor from NEPacketTunnelProvider + var dupTunFd int + dupTunFd, err = unix.Dup(t.tunFd) if err != nil { log.Errorf("Unable to dup tun fd: %v", err) return nil, err @@ -60,7 +82,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) { _ = unix.Close(dupTunFd) return nil, err } - tunDevice, err := tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0) + tunDevice, err = tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0) if err != nil { log.Errorf("Unable to create new tun device from fd: %v", err) _ = unix.Close(dupTunFd) diff --git a/client/iface/wgproxy/factory_kernel.go b/client/iface/wgproxy/factory_kernel.go index ad2807546..2714c5774 100644 --- a/client/iface/wgproxy/factory_kernel.go +++ b/client/iface/wgproxy/factory_kernel.go @@ -3,12 +3,19 @@ package wgproxy import ( + "os" + "strconv" + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp" ) +const ( + envDisableEBPFWGProxy = "NB_DISABLE_EBPF_WG_PROXY" +) + type KernelFactory struct { wgPort int mtu uint16 @@ -22,6 +29,12 @@ func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory { mtu: mtu, } + if isEBPFDisabled() { + log.Infof("WireGuard Proxy Factory will produce UDP proxy") + log.Infof("eBPF WireGuard proxy is disabled via %s environment variable", envDisableEBPFWGProxy) + return f + } + ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, mtu) if err := ebpfProxy.Listen(); err != nil { log.Infof("WireGuard Proxy Factory will produce UDP proxy") @@ -47,3 +60,16 @@ func (w *KernelFactory) Free() error { } return w.ebpfProxy.Free() } + +func isEBPFDisabled() bool { + val := os.Getenv(envDisableEBPFWGProxy) + if val == "" { + return false + } + disabled, err := strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", envDisableEBPFWGProxy, err) + return false + } + return disabled +} diff --git a/client/internal/connect.go b/client/internal/connect.go index e9d422a28..017c8bf10 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -24,10 +24,14 @@ import ( "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/stdnet" + "github.com/netbirdio/netbird/client/internal/updatemanager" + "github.com/netbirdio/netbird/client/internal/updatemanager/installer" nbnet "github.com/netbirdio/netbird/client/net" cProto "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/ssh" + sshconfig "github.com/netbirdio/netbird/client/ssh/config" "github.com/netbirdio/netbird/client/system" mgm "github.com/netbirdio/netbird/shared/management/client" mgmProto "github.com/netbirdio/netbird/shared/management/proto" @@ -39,11 +43,13 @@ import ( ) type ConnectClient struct { - ctx context.Context - config *profilemanager.Config - statusRecorder *peer.Status - engine *Engine - engineMutex sync.Mutex + ctx context.Context + config *profilemanager.Config + statusRecorder *peer.Status + doInitialAutoUpdate bool + + engine *Engine + engineMutex sync.Mutex persistSyncResponse bool } @@ -52,13 +58,15 @@ func NewConnectClient( ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, + doInitalAutoUpdate bool, ) *ConnectClient { return &ConnectClient{ - ctx: ctx, - config: config, - statusRecorder: statusRecorder, - engineMutex: sync.Mutex{}, + ctx: ctx, + config: config, + statusRecorder: statusRecorder, + doInitialAutoUpdate: doInitalAutoUpdate, + engineMutex: sync.Mutex{}, } } @@ -162,6 +170,33 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan return err } + var path string + if runtime.GOOS == "ios" || runtime.GOOS == "android" { + // On mobile, use the provided state file path directly + if !fileExists(mobileDependency.StateFilePath) { + if err := createFile(mobileDependency.StateFilePath); err != nil { + log.Errorf("failed to create state file: %v", err) + // we are not exiting as we can run without the state manager + } + } + path = mobileDependency.StateFilePath + } else { + sm := profilemanager.NewServiceManager("") + path = sm.GetStatePath() + } + stateManager := statemanager.New(path) + stateManager.RegisterState(&sshconfig.ShutdownState{}) + + updateManager, err := updatemanager.NewManager(c.statusRecorder, stateManager) + if err == nil { + updateManager.CheckUpdateSuccess(c.ctx) + + inst := installer.New() + if err := inst.CleanUpInstallerFiles(); err != nil { + log.Errorf("failed to clean up temporary installer file: %v", err) + } + } + defer c.statusRecorder.ClientStop() operation := func() error { // if context cancelled we not start new backoff cycle @@ -273,7 +308,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan checks := loginResp.GetChecks() c.engineMutex.Lock() - engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks) + engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager) engine.SetSyncResponsePersistence(c.persistSyncResponse) c.engine = engine c.engineMutex.Unlock() @@ -283,6 +318,15 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan return wrapErr(err) } + if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil { + // AutoUpdate will be true when the user click on "Connect" menu on the UI + if c.doInitialAutoUpdate { + log.Infof("start engine by ui, run auto-update check") + c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate) + c.doInitialAutoUpdate = false + } + } + log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) state.Set(StatusConnected) diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index 3c201ecfc..01a0377a5 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -27,6 +27,7 @@ import ( "github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/internal/updatemanager/installer" mgmProto "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/util" ) @@ -362,6 +363,10 @@ func (g *BundleGenerator) createArchive() error { log.Errorf("failed to add systemd logs: %v", err) } + if err := g.addUpdateLogs(); err != nil { + log.Errorf("failed to add updater logs: %v", err) + } + return nil } @@ -650,6 +655,29 @@ func (g *BundleGenerator) addStateFile() error { return nil } +func (g *BundleGenerator) addUpdateLogs() error { + inst := installer.New() + logFiles := inst.LogFiles() + if len(logFiles) == 0 { + return nil + } + + log.Infof("adding updater logs") + for _, logFile := range logFiles { + data, err := os.ReadFile(logFile) + if err != nil { + log.Warnf("failed to read update log file %s: %v", logFile, err) + continue + } + + baseName := filepath.Base(logFile) + if err := g.addFileToZip(bytes.NewReader(data), filepath.Join("update-logs", baseName)); err != nil { + return fmt.Errorf("add update log file %s to zip: %w", baseName, err) + } + } + return nil +} + func (g *BundleGenerator) addCorruptedStateFiles() error { sm := profilemanager.NewServiceManager("") pattern := sm.GetStatePath() diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go index 290395473..d01be0c2c 100644 --- a/client/internal/dns/mgmt/mgmt.go +++ b/client/internal/dns/mgmt/mgmt.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "net/netip" "net/url" "strings" "sync" @@ -26,6 +27,11 @@ type Resolver struct { mutex sync.RWMutex } +type ipsResponse struct { + ips []netip.Addr + err error +} + // NewResolver creates a new management domains cache resolver. func NewResolver() *Resolver { return &Resolver{ @@ -99,9 +105,9 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { ctx, cancel := context.WithTimeout(ctx, dnsTimeout) defer cancel() - ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString()) + ips, err := lookupIPWithExtraTimeout(ctx, d) if err != nil { - return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err) + return err } var aRecords, aaaaRecords []dns.RR @@ -159,6 +165,36 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { return nil } +func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) { + log.Infof("looking up IP for mgmt domain=%s", d.SafeString()) + defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString()) + resultChan := make(chan *ipsResponse, 1) + + go func() { + ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString()) + resultChan <- &ipsResponse{ + err: err, + ips: ips, + } + }() + + var resp *ipsResponse + + select { + case <-time.After(dnsTimeout + time.Millisecond*500): + log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString()) + return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString()) + case <-ctx.Done(): + return nil, ctx.Err() + case resp = <-resultChan: + } + + if resp.err != nil { + return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err) + } + return resp.ips, nil +} + // PopulateFromConfig extracts and caches domains from the client configuration. func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error { if mgmtURL == nil { diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index afaf0579f..94945b55a 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -80,6 +80,7 @@ type DefaultServer struct { updateSerial uint64 previousConfigHash uint64 currentConfig HostDNSConfig + currentConfigHash uint64 handlerChain *HandlerChain extraDomains map[domain.Domain]int @@ -207,6 +208,7 @@ func newDefaultServer( hostsDNSHolder: newHostsDNSHolder(), hostManager: &noopHostConfigurator{}, mgmtCacheResolver: mgmtCacheResolver, + currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied } // register with root zone, handler chain takes care of the routing @@ -586,8 +588,29 @@ func (s *DefaultServer) applyHostConfig() { log.Debugf("extra match domains: %v", maps.Keys(s.extraDomains)) + hash, err := hashstructure.Hash(config, hashstructure.FormatV2, &hashstructure.HashOptions{ + ZeroNil: true, + IgnoreZeroValue: true, + SlicesAsSets: true, + UseStringer: true, + }) + if err != nil { + log.Warnf("unable to hash the host dns configuration, will apply config anyway: %s", err) + // Fall through to apply config anyway (fail-safe approach) + } else if s.currentConfigHash == hash { + log.Debugf("not applying host config as there are no changes") + return + } + + log.Debugf("applying host config as there are changes") if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil { log.Errorf("failed to apply DNS host manager update: %v", err) + return + } + + // Only update hash if it was computed successfully and config was applied + if err == nil { + s.currentConfigHash = hash } s.registerFallback(config) diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index d12070128..fe1f67f66 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -1602,7 +1602,10 @@ func TestExtraDomains(t *testing.T) { "other.example.com.", "duplicate.example.com.", }, - applyHostConfigCall: 4, + // Expect 3 calls instead of 4 because when deregistering duplicate.example.com, + // the domain remains in the config (ref count goes from 2 to 1), so the host + // config hash doesn't change and applyDNSConfig is not called. + applyHostConfigCall: 3, }, { name: "Config update with new domains after registration", @@ -1657,7 +1660,10 @@ func TestExtraDomains(t *testing.T) { expectedMatchOnly: []string{ "extra.example.com.", }, - applyHostConfigCall: 3, + // Expect 2 calls instead of 3 because when deregistering protected.example.com, + // it's removed from extraDomains but still remains in the config (from customZones), + // so the host config hash doesn't change and applyDNSConfig is not called. + applyHostConfigCall: 2, }, { name: "Register domain that is part of nameserver group", diff --git a/client/internal/engine.go b/client/internal/engine.go index ff1cec19a..4f18c3bc8 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -42,14 +42,13 @@ import ( "github.com/netbirdio/netbird/client/internal/peer/guard" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/peerstore" - "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/internal/updatemanager" cProto "github.com/netbirdio/netbird/client/proto" - sshconfig "github.com/netbirdio/netbird/client/ssh/config" "github.com/netbirdio/netbird/shared/management/domain" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" @@ -73,6 +72,7 @@ const ( PeerConnectionTimeoutMax = 45000 // ms PeerConnectionTimeoutMin = 30000 // ms connInitLimit = 200 + disableAutoUpdate = "disabled" ) var ErrResetConnection = fmt.Errorf("reset connection") @@ -201,6 +201,9 @@ type Engine struct { connSemaphore *semaphoregroup.SemaphoreGroup flowManager nftypes.FlowManager + // auto-update + updateManager *updatemanager.Manager + // WireGuard interface monitor wgIfaceMonitor *WGIfaceMonitor @@ -221,17 +224,7 @@ type localIpUpdater interface { } // NewEngine creates a new Connection Engine with probes attached -func NewEngine( - clientCtx context.Context, - clientCancel context.CancelFunc, - signalClient signal.Client, - mgmClient mgm.Client, - relayManager *relayClient.Manager, - config *EngineConfig, - mobileDep MobileDependency, - statusRecorder *peer.Status, - checks []*mgmProto.Checks, -) *Engine { +func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, relayManager *relayClient.Manager, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, checks []*mgmProto.Checks, stateManager *statemanager.Manager) *Engine { engine := &Engine{ clientCtx: clientCtx, clientCancel: clientCancel, @@ -247,28 +240,12 @@ func NewEngine( TURNs: []*stun.URI{}, networkSerial: 0, statusRecorder: statusRecorder, + stateManager: stateManager, checks: checks, connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), } - sm := profilemanager.NewServiceManager("") - - path := sm.GetStatePath() - if runtime.GOOS == "ios" || runtime.GOOS == "android" { - if !fileExists(mobileDep.StateFilePath) { - err := createFile(mobileDep.StateFilePath) - if err != nil { - log.Errorf("failed to create state file: %v", err) - // we are not exiting as we can run without the state manager - } - } - - path = mobileDep.StateFilePath - } - engine.stateManager = statemanager.New(path) - engine.stateManager.RegisterState(&sshconfig.ShutdownState{}) - log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String()) return engine } @@ -308,6 +285,10 @@ func (e *Engine) Stop() error { e.srWatcher.Close() } + if e.updateManager != nil { + e.updateManager.Stop() + } + log.Info("cleaning up status recorder states") e.statusRecorder.ReplaceOfflinePeers([]peer.State{}) e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{}) @@ -541,6 +522,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) return nil } +func (e *Engine) InitialUpdateHandling(autoUpdateSettings *mgmProto.AutoUpdateSettings) { + e.syncMsgMux.Lock() + defer e.syncMsgMux.Unlock() + + e.handleAutoUpdateVersion(autoUpdateSettings, true) +} + func (e *Engine) createFirewall() error { if e.config.DisableFirewall { log.Infof("firewall is disabled") @@ -749,6 +737,41 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg return nil } +func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings, initialCheck bool) { + if autoUpdateSettings == nil { + return + } + + disabled := autoUpdateSettings.Version == disableAutoUpdate + + // Stop and cleanup if disabled + if e.updateManager != nil && disabled { + log.Infof("auto-update is disabled, stopping update manager") + e.updateManager.Stop() + e.updateManager = nil + return + } + + // Skip check unless AlwaysUpdate is enabled or this is the initial check at startup + if !autoUpdateSettings.AlwaysUpdate && !initialCheck { + log.Debugf("skipping auto-update check, AlwaysUpdate is false and this is not the initial check") + return + } + + // Start manager if needed + if e.updateManager == nil { + log.Infof("starting auto-update manager") + updateManager, err := updatemanager.NewManager(e.statusRecorder, e.stateManager) + if err != nil { + return + } + e.updateManager = updateManager + e.updateManager.Start(e.ctx) + } + log.Infof("handling auto-update version: %s", autoUpdateSettings.Version) + e.updateManager.SetVersion(autoUpdateSettings.Version) +} + func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() @@ -758,6 +781,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { return e.ctx.Err() } + if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil { + e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate, false) + } + if update.GetNetbirdConfig() != nil { wCfg := update.GetNetbirdConfig() err := e.updateTURNs(wCfg.GetTurns()) @@ -1094,6 +1121,15 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { e.updateOfflinePeers(networkMap.GetOfflinePeers()) + // Filter out own peer from the remote peers list + localPubKey := e.config.WgPrivateKey.PublicKey().String() + remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers())) + for _, p := range networkMap.GetRemotePeers() { + if p.GetWgPubKey() != localPubKey { + remotePeers = append(remotePeers, p) + } + } + // cleanup request, most likely our peer has been deleted if networkMap.GetRemotePeersIsEmpty() { err := e.removeAllPeers() @@ -1102,32 +1138,34 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { return err } } else { - err := e.removePeers(networkMap.GetRemotePeers()) + err := e.removePeers(remotePeers) if err != nil { return err } - err = e.modifyPeers(networkMap.GetRemotePeers()) + err = e.modifyPeers(remotePeers) if err != nil { return err } - err = e.addNewPeers(networkMap.GetRemotePeers()) + err = e.addNewPeers(remotePeers) if err != nil { return err } e.statusRecorder.FinishPeerListModifications() - e.updatePeerSSHHostKeys(networkMap.GetRemotePeers()) + e.updatePeerSSHHostKeys(remotePeers) - if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil { + if err := e.updateSSHClientConfig(remotePeers); err != nil { log.Warnf("failed to update SSH client config: %v", err) } + + e.updateSSHServerAuth(networkMap.GetSshAuth()) } // must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store - excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, networkMap.GetRemotePeers()) + excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers) e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers) e.networkSerial = serial diff --git a/client/internal/engine_ssh.go b/client/internal/engine_ssh.go index 861b3d6d2..e683d8cee 100644 --- a/client/internal/engine_ssh.go +++ b/client/internal/engine_ssh.go @@ -11,15 +11,18 @@ import ( firewallManager "github.com/netbirdio/netbird/client/firewall/manager" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" + sshauth "github.com/netbirdio/netbird/client/ssh/auth" sshconfig "github.com/netbirdio/netbird/client/ssh/config" sshserver "github.com/netbirdio/netbird/client/ssh/server" mgmProto "github.com/netbirdio/netbird/shared/management/proto" + sshuserhash "github.com/netbirdio/netbird/shared/sshauth" ) type sshServer interface { Start(ctx context.Context, addr netip.AddrPort) error Stop() error GetStatus() (bool, []sshserver.SessionInfo) + UpdateSSHAuth(config *sshauth.Config) } func (e *Engine) setupSSHPortRedirection() error { @@ -353,3 +356,38 @@ func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.Sessio return sshServer.GetStatus() } + +// updateSSHServerAuth updates SSH fine-grained access control configuration on a running SSH server +func (e *Engine) updateSSHServerAuth(sshAuth *mgmProto.SSHAuth) { + if sshAuth == nil { + return + } + + if e.sshServer == nil { + return + } + + protoUsers := sshAuth.GetAuthorizedUsers() + authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers)) + for i, hash := range protoUsers { + if len(hash) != 16 { + log.Warnf("invalid hash length %d, expected 16 - skipping SSH server auth update", len(hash)) + return + } + authorizedUsers[i] = sshuserhash.UserIDHash(hash) + } + + machineUsers := make(map[string][]uint32) + for osUser, indexes := range sshAuth.GetMachineUsers() { + machineUsers[osUser] = indexes.GetIndexes() + } + + // Update SSH server with new authorization configuration + authConfig := &sshauth.Config{ + UserIDClaim: sshAuth.GetUserIDClaim(), + AuthorizedUsers: authorizedUsers, + MachineUsers: machineUsers, + } + + e.sshServer.UpdateSSHAuth(authConfig) +} diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 5ab21e3e1..a15ee0581 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -253,6 +253,7 @@ func TestEngine_SSH(t *testing.T) { MobileDependency{}, peer.NewRecorder("https://mgm"), nil, + nil, ) engine.dnsServer = &dns.MockServer{ @@ -414,21 +415,13 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { defer cancel() relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) - engine := NewEngine( - ctx, cancel, - &signal.MockClient{}, - &mgmt.MockClient{}, - relayMgr, - &EngineConfig{ - WgIfaceName: "utun102", - WgAddr: "100.64.0.1/24", - WgPrivateKey: key, - WgPort: 33100, - MTU: iface.DefaultMTU, - }, - MobileDependency{}, - peer.NewRecorder("https://mgm"), - nil) + engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ + WgIfaceName: "utun102", + WgAddr: "100.64.0.1/24", + WgPrivateKey: key, + WgPort: 33100, + MTU: iface.DefaultMTU, + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) wgIface := &MockWGIface{ NameFunc: func() string { return "utun102" }, @@ -647,7 +640,7 @@ func TestEngine_Sync(t *testing.T) { WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) engine.ctx = ctx engine.dnsServer = &dns.MockServer{ @@ -812,7 +805,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) engine.ctx = ctx newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { @@ -1014,7 +1007,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) engine.ctx = ctx newNet, err := stdnet.NewNet(context.Background(), nil) @@ -1540,7 +1533,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin } relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) - e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil + e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil e.ctx = ctx return e, err } @@ -1638,7 +1631,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri if err != nil { return nil, "", err } - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil) if err != nil { return nil, "", err } diff --git a/client/internal/networkmonitor/check_change_darwin.go b/client/internal/networkmonitor/check_change_darwin.go index ddc6e1736..cb5236070 100644 --- a/client/internal/networkmonitor/check_change_darwin.go +++ b/client/internal/networkmonitor/check_change_darwin.go @@ -110,7 +110,6 @@ func wakeUpListen(ctx context.Context) { } if newHash == initialHash { - log.Tracef("no wakeup detected") continue } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 426c31e1a..20a2eb342 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -148,13 +148,15 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) { // It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will // be used. func (conn *Conn) Open(engineCtx context.Context) error { - conn.semaphore.Add(engineCtx) + if err := conn.semaphore.Add(engineCtx); err != nil { + return err + } conn.mu.Lock() defer conn.mu.Unlock() if conn.opened { - conn.semaphore.Done(engineCtx) + conn.semaphore.Done() return nil } @@ -165,6 +167,7 @@ func (conn *Conn) Open(engineCtx context.Context) error { relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) if err != nil { + conn.semaphore.Done() return err } conn.workerICE = workerICE @@ -200,7 +203,7 @@ func (conn *Conn) Open(engineCtx context.Context) error { defer conn.wg.Done() conn.waitInitialRandomSleepTime(conn.ctx) - conn.semaphore.Done(conn.ctx) + conn.semaphore.Done() conn.guard.Start(conn.ctx, conn.onGuardEvent) }() diff --git a/client/internal/profilemanager/config.go b/client/internal/profilemanager/config.go index 8f467a214..f2fda84e0 100644 --- a/client/internal/profilemanager/config.go +++ b/client/internal/profilemanager/config.go @@ -3,9 +3,11 @@ package profilemanager import ( "context" "crypto/tls" + "encoding/json" "fmt" "net/url" "os" + "os/user" "path/filepath" "reflect" "runtime" @@ -165,19 +167,26 @@ func getConfigDir() (string, error) { if ConfigDirOverride != "" { return ConfigDirOverride, nil } - configDir, err := os.UserConfigDir() + + base, err := baseConfigDir() if err != nil { return "", err } - configDir = filepath.Join(configDir, "netbird") - if _, err := os.Stat(configDir); os.IsNotExist(err) { - if err := os.MkdirAll(configDir, 0755); err != nil { - return "", err + configDir := filepath.Join(base, "netbird") + if err := os.MkdirAll(configDir, 0o755); err != nil { + return "", err + } + return configDir, nil +} + +func baseConfigDir() (string, error) { + if runtime.GOOS == "darwin" { + if u, err := user.Current(); err == nil && u.HomeDir != "" { + return filepath.Join(u.HomeDir, "Library", "Application Support"), nil } } - - return configDir, nil + return os.UserConfigDir() } func getConfigDirForUser(username string) (string, error) { @@ -676,7 +685,7 @@ func update(input ConfigInput) (*Config, error) { return config, nil } -// GetConfig read config file and return with Config. Errors out if it does not exist +// GetConfig read config file and return with Config and if it was created. Errors out if it does not exist func GetConfig(configPath string) (*Config, error) { return readConfig(configPath, false) } @@ -812,3 +821,85 @@ func readConfig(configPath string, createIfMissing bool) (*Config, error) { func WriteOutConfig(path string, config *Config) error { return util.WriteJson(context.Background(), path, config) } + +// DirectWriteOutConfig writes config directly without atomic temp file operations. +// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox). +func DirectWriteOutConfig(path string, config *Config) error { + return util.DirectWriteJson(context.Background(), path, config) +} + +// DirectUpdateOrCreateConfig is like UpdateOrCreateConfig but uses direct (non-atomic) writes. +// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox). +func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) { + if !fileExists(input.ConfigPath) { + log.Infof("generating new config %s", input.ConfigPath) + cfg, err := createNewConfig(input) + if err != nil { + return nil, err + } + err = util.DirectWriteJson(context.Background(), input.ConfigPath, cfg) + return cfg, err + } + + if isPreSharedKeyHidden(input.PreSharedKey) { + input.PreSharedKey = nil + } + + // Enforce permissions on existing config files (same as UpdateOrCreateConfig) + if err := util.EnforcePermission(input.ConfigPath); err != nil { + log.Errorf("failed to enforce permission on config file: %v", err) + } + + return directUpdate(input) +} + +func directUpdate(input ConfigInput) (*Config, error) { + config := &Config{} + + if _, err := util.ReadJson(input.ConfigPath, config); err != nil { + return nil, err + } + + updated, err := config.apply(input) + if err != nil { + return nil, err + } + + if updated { + if err := util.DirectWriteJson(context.Background(), input.ConfigPath, config); err != nil { + return nil, err + } + } + + return config, nil +} + +// ConfigToJSON serializes a Config struct to a JSON string. +// This is useful for exporting config to alternative storage mechanisms +// (e.g., UserDefaults on tvOS where file writes are blocked). +func ConfigToJSON(config *Config) (string, error) { + bs, err := json.MarshalIndent(config, "", " ") + if err != nil { + return "", err + } + return string(bs), nil +} + +// ConfigFromJSON deserializes a JSON string to a Config struct. +// This is useful for restoring config from alternative storage mechanisms. +// After unmarshaling, defaults are applied to ensure the config is fully initialized. +func ConfigFromJSON(jsonStr string) (*Config, error) { + config := &Config{} + err := json.Unmarshal([]byte(jsonStr), config) + if err != nil { + return nil, err + } + + // Apply defaults to ensure required fields are initialized. + // This mirrors what readConfig does after loading from file. + if _, err := config.apply(ConfigInput{}); err != nil { + return nil, fmt.Errorf("failed to apply defaults to config: %w", err) + } + + return config, nil +} diff --git a/client/internal/profilemanager/service.go b/client/internal/profilemanager/service.go index faccf5f68..bdb722c67 100644 --- a/client/internal/profilemanager/service.go +++ b/client/internal/profilemanager/service.go @@ -76,6 +76,7 @@ func (a *ActiveProfileState) FilePath() (string, error) { } type ServiceManager struct { + profilesDir string // If set, overrides ConfigDirOverride for profile operations } func NewServiceManager(defaultConfigPath string) *ServiceManager { @@ -85,6 +86,17 @@ func NewServiceManager(defaultConfigPath string) *ServiceManager { return &ServiceManager{} } +// NewServiceManagerWithProfilesDir creates a ServiceManager with a specific profiles directory +// This allows setting the profiles directory without modifying the global ConfigDirOverride +func NewServiceManagerWithProfilesDir(defaultConfigPath string, profilesDir string) *ServiceManager { + if defaultConfigPath != "" { + DefaultConfigPath = defaultConfigPath + } + return &ServiceManager{ + profilesDir: profilesDir, + } +} + func (s *ServiceManager) CopyDefaultProfileIfNotExists() (bool, error) { if err := os.MkdirAll(DefaultConfigPathDir, 0600); err != nil { @@ -114,14 +126,6 @@ func (s *ServiceManager) CopyDefaultProfileIfNotExists() (bool, error) { log.Warnf("failed to set permissions for default profile: %v", err) } - if err := s.SetActiveProfileState(&ActiveProfileState{ - Name: "default", - Username: "", - }); err != nil { - log.Errorf("failed to set active profile state: %v", err) - return false, fmt.Errorf("failed to set active profile state: %w", err) - } - return true, nil } @@ -240,7 +244,7 @@ func (s *ServiceManager) DefaultProfilePath() string { } func (s *ServiceManager) AddProfile(profileName, username string) error { - configDir, err := getConfigDirForUser(username) + configDir, err := s.getConfigDir(username) if err != nil { return fmt.Errorf("failed to get config directory: %w", err) } @@ -270,7 +274,7 @@ func (s *ServiceManager) AddProfile(profileName, username string) error { } func (s *ServiceManager) RemoveProfile(profileName, username string) error { - configDir, err := getConfigDirForUser(username) + configDir, err := s.getConfigDir(username) if err != nil { return fmt.Errorf("failed to get config directory: %w", err) } @@ -302,7 +306,7 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error { } func (s *ServiceManager) ListProfiles(username string) ([]Profile, error) { - configDir, err := getConfigDirForUser(username) + configDir, err := s.getConfigDir(username) if err != nil { return nil, fmt.Errorf("failed to get config directory: %w", err) } @@ -361,7 +365,7 @@ func (s *ServiceManager) GetStatePath() string { return defaultStatePath } - configDir, err := getConfigDirForUser(activeProf.Username) + configDir, err := s.getConfigDir(activeProf.Username) if err != nil { log.Warnf("failed to get config directory for user %s: %v", activeProf.Username, err) return defaultStatePath @@ -369,3 +373,12 @@ func (s *ServiceManager) GetStatePath() string { return filepath.Join(configDir, activeProf.Name+".state.json") } + +// getConfigDir returns the profiles directory, using profilesDir if set, otherwise getConfigDirForUser +func (s *ServiceManager) getConfigDir(username string) (string, error) { + if s.profilesDir != "" { + return s.profilesDir, nil + } + + return getConfigDirForUser(username) +} diff --git a/client/internal/updatemanager/doc.go b/client/internal/updatemanager/doc.go new file mode 100644 index 000000000..54d1bdeab --- /dev/null +++ b/client/internal/updatemanager/doc.go @@ -0,0 +1,35 @@ +// Package updatemanager provides automatic update management for the NetBird client. +// It monitors for new versions, handles update triggers from management server directives, +// and orchestrates the download and installation of client updates. +// +// # Overview +// +// The update manager operates as a background service that continuously monitors for +// available updates and automatically initiates the update process when conditions are met. +// It integrates with the installer package to perform the actual installation. +// +// # Update Flow +// +// The complete update process follows these steps: +// +// 1. Manager receives update directive via SetVersion() or detects new version +// 2. Manager validates update should proceed (version comparison, rate limiting) +// 3. Manager publishes "updating" event to status recorder +// 4. Manager persists UpdateState to track update attempt +// 5. Manager downloads installer file (.msi or .exe) to temporary directory +// 6. Manager triggers installation via installer.RunInstallation() +// 7. Installer package handles the actual installation process +// 8. On next startup, CheckUpdateSuccess() verifies update completion +// 9. Manager publishes success/failure event to status recorder +// 10. Manager cleans up UpdateState +// +// # State Management +// +// Update state is persisted across restarts to track update attempts: +// +// - PreUpdateVersion: Version before update attempt +// - TargetVersion: Version attempting to update to +// +// This enables verification of successful updates and appropriate user notification +// after the client restarts with the new version. +package updatemanager diff --git a/client/internal/updatemanager/downloader/downloader.go b/client/internal/updatemanager/downloader/downloader.go new file mode 100644 index 000000000..2ac36efed --- /dev/null +++ b/client/internal/updatemanager/downloader/downloader.go @@ -0,0 +1,138 @@ +package downloader + +import ( + "context" + "fmt" + "io" + "net/http" + "os" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/version" +) + +const ( + userAgent = "NetBird agent installer/%s" + DefaultRetryDelay = 3 * time.Second +) + +func DownloadToFile(ctx context.Context, retryDelay time.Duration, url, dstFile string) error { + log.Debugf("starting download from %s", url) + + out, err := os.Create(dstFile) + if err != nil { + return fmt.Errorf("failed to create destination file %q: %w", dstFile, err) + } + defer func() { + if cerr := out.Close(); cerr != nil { + log.Warnf("error closing file %q: %v", dstFile, cerr) + } + }() + + // First attempt + err = downloadToFileOnce(ctx, url, out) + if err == nil { + log.Infof("successfully downloaded file to %s", dstFile) + return nil + } + + // If retryDelay is 0, don't retry + if retryDelay == 0 { + return err + } + + log.Warnf("download failed, retrying after %v: %v", retryDelay, err) + + // Sleep before retry + if sleepErr := sleepWithContext(ctx, retryDelay); sleepErr != nil { + return fmt.Errorf("download cancelled during retry delay: %w", sleepErr) + } + + // Truncate file before retry + if err := out.Truncate(0); err != nil { + return fmt.Errorf("failed to truncate file on retry: %w", err) + } + if _, err := out.Seek(0, 0); err != nil { + return fmt.Errorf("failed to seek to beginning of file: %w", err) + } + + // Second attempt + if err := downloadToFileOnce(ctx, url, out); err != nil { + return fmt.Errorf("download failed after retry: %w", err) + } + + log.Infof("successfully downloaded file to %s", dstFile) + return nil +} + +func DownloadToMemory(ctx context.Context, url string, limit int64) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) + } + + // Add User-Agent header + req.Header.Set("User-Agent", fmt.Sprintf(userAgent, version.NetbirdVersion())) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to perform HTTP request: %w", err) + } + defer func() { + if cerr := resp.Body.Close(); cerr != nil { + log.Warnf("error closing response body: %v", cerr) + } + }() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode) + } + + data, err := io.ReadAll(io.LimitReader(resp.Body, limit)) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + return data, nil +} + +func downloadToFileOnce(ctx context.Context, url string, out *os.File) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return fmt.Errorf("failed to create HTTP request: %w", err) + } + + // Add User-Agent header + req.Header.Set("User-Agent", fmt.Sprintf(userAgent, version.NetbirdVersion())) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("failed to perform HTTP request: %w", err) + } + defer func() { + if cerr := resp.Body.Close(); cerr != nil { + log.Warnf("error closing response body: %v", cerr) + } + }() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode) + } + + if _, err := io.Copy(out, resp.Body); err != nil { + return fmt.Errorf("failed to write response body to file: %w", err) + } + + return nil +} + +func sleepWithContext(ctx context.Context, duration time.Duration) error { + select { + case <-time.After(duration): + return nil + case <-ctx.Done(): + return ctx.Err() + } +} diff --git a/client/internal/updatemanager/downloader/downloader_test.go b/client/internal/updatemanager/downloader/downloader_test.go new file mode 100644 index 000000000..045db3a2d --- /dev/null +++ b/client/internal/updatemanager/downloader/downloader_test.go @@ -0,0 +1,199 @@ +package downloader + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "sync/atomic" + "testing" + "time" +) + +const ( + retryDelay = 100 * time.Millisecond +) + +func TestDownloadToFile_Success(t *testing.T) { + // Create a test server that responds successfully + content := "test file content" + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(content)) + })) + defer server.Close() + + // Create a temporary file for download + tempDir := t.TempDir() + dstFile := filepath.Join(tempDir, "downloaded.txt") + + // Download the file + err := DownloadToFile(context.Background(), retryDelay, server.URL, dstFile) + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + // Verify the file content + data, err := os.ReadFile(dstFile) + if err != nil { + t.Fatalf("failed to read downloaded file: %v", err) + } + + if string(data) != content { + t.Errorf("expected content %q, got %q", content, string(data)) + } +} + +func TestDownloadToFile_SuccessAfterRetry(t *testing.T) { + content := "test file content after retry" + var attemptCount atomic.Int32 + + // Create a test server that fails on first attempt, succeeds on second + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempt := attemptCount.Add(1) + if attempt == 1 { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("error")) + return + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(content)) + })) + defer server.Close() + + // Create a temporary file for download + tempDir := t.TempDir() + dstFile := filepath.Join(tempDir, "downloaded.txt") + + // Download the file (should succeed after retry) + if err := DownloadToFile(context.Background(), 10*time.Millisecond, server.URL, dstFile); err != nil { + t.Fatalf("expected no error after retry, got: %v", err) + } + + // Verify the file content + data, err := os.ReadFile(dstFile) + if err != nil { + t.Fatalf("failed to read downloaded file: %v", err) + } + + if string(data) != content { + t.Errorf("expected content %q, got %q", content, string(data)) + } + + // Verify it took 2 attempts + if attemptCount.Load() != 2 { + t.Errorf("expected 2 attempts, got %d", attemptCount.Load()) + } +} + +func TestDownloadToFile_FailsAfterRetry(t *testing.T) { + var attemptCount atomic.Int32 + + // Create a test server that always fails + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount.Add(1) + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("error")) + })) + defer server.Close() + + // Create a temporary file for download + tempDir := t.TempDir() + dstFile := filepath.Join(tempDir, "downloaded.txt") + + // Download the file (should fail after retry) + if err := DownloadToFile(context.Background(), 10*time.Millisecond, server.URL, dstFile); err == nil { + t.Fatal("expected error after retry, got nil") + } + + // Verify it tried 2 times + if attemptCount.Load() != 2 { + t.Errorf("expected 2 attempts, got %d", attemptCount.Load()) + } +} + +func TestDownloadToFile_ContextCancellationDuringRetry(t *testing.T) { + var attemptCount atomic.Int32 + + // Create a test server that always fails + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount.Add(1) + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + // Create a temporary file for download + tempDir := t.TempDir() + dstFile := filepath.Join(tempDir, "downloaded.txt") + + // Create a context that will be cancelled during retry delay + ctx, cancel := context.WithCancel(context.Background()) + + // Cancel after a short delay (during the retry sleep) + go func() { + time.Sleep(100 * time.Millisecond) + cancel() + }() + + // Download the file (should fail due to context cancellation during retry) + err := DownloadToFile(ctx, 1*time.Second, server.URL, dstFile) + if err == nil { + t.Fatal("expected error due to context cancellation, got nil") + } + + // Should have only made 1 attempt (cancelled during retry delay) + if attemptCount.Load() != 1 { + t.Errorf("expected 1 attempt, got %d", attemptCount.Load()) + } +} + +func TestDownloadToFile_InvalidURL(t *testing.T) { + tempDir := t.TempDir() + dstFile := filepath.Join(tempDir, "downloaded.txt") + + err := DownloadToFile(context.Background(), retryDelay, "://invalid-url", dstFile) + if err == nil { + t.Fatal("expected error for invalid URL, got nil") + } +} + +func TestDownloadToFile_InvalidDestination(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("test")) + })) + defer server.Close() + + // Use an invalid destination path + err := DownloadToFile(context.Background(), retryDelay, server.URL, "/invalid/path/that/does/not/exist/file.txt") + if err == nil { + t.Fatal("expected error for invalid destination, got nil") + } +} + +func TestDownloadToFile_NoRetry(t *testing.T) { + var attemptCount atomic.Int32 + + // Create a test server that always fails + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount.Add(1) + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("error")) + })) + defer server.Close() + + // Create a temporary file for download + tempDir := t.TempDir() + dstFile := filepath.Join(tempDir, "downloaded.txt") + + // Download the file with retryDelay = 0 (should not retry) + if err := DownloadToFile(context.Background(), 0, server.URL, dstFile); err == nil { + t.Fatal("expected error, got nil") + } + + // Verify it only made 1 attempt (no retry) + if attemptCount.Load() != 1 { + t.Errorf("expected 1 attempt, got %d", attemptCount.Load()) + } +} diff --git a/client/internal/updatemanager/installer/binary_nowindows.go b/client/internal/updatemanager/installer/binary_nowindows.go new file mode 100644 index 000000000..19f3bef83 --- /dev/null +++ b/client/internal/updatemanager/installer/binary_nowindows.go @@ -0,0 +1,7 @@ +//go:build !windows + +package installer + +func UpdaterBinaryNameWithoutExtension() string { + return updaterBinary +} diff --git a/client/internal/updatemanager/installer/binary_windows.go b/client/internal/updatemanager/installer/binary_windows.go new file mode 100644 index 000000000..4c66391c2 --- /dev/null +++ b/client/internal/updatemanager/installer/binary_windows.go @@ -0,0 +1,11 @@ +package installer + +import ( + "path/filepath" + "strings" +) + +func UpdaterBinaryNameWithoutExtension() string { + ext := filepath.Ext(updaterBinary) + return strings.TrimSuffix(updaterBinary, ext) +} diff --git a/client/internal/updatemanager/installer/doc.go b/client/internal/updatemanager/installer/doc.go new file mode 100644 index 000000000..0a60454bb --- /dev/null +++ b/client/internal/updatemanager/installer/doc.go @@ -0,0 +1,111 @@ +// Package installer provides functionality for managing NetBird application +// updates and installations across Windows, macOS. It handles +// the complete update lifecycle including artifact download, cryptographic verification, +// installation execution, process management, and result reporting. +// +// # Architecture +// +// The installer package uses a two-process architecture to enable self-updates: +// +// 1. Service Process: The main NetBird daemon process that initiates updates +// 2. Updater Process: A detached child process that performs the actual installation +// +// This separation is critical because: +// - The service binary cannot update itself while running +// - The installer (EXE/MSI/PKG) will terminate the service during installation +// - The updater process survives service termination and restarts it after installation +// - Results can be communicated back to the service after it restarts +// +// # Update Flow +// +// Service Process (RunInstallation): +// +// 1. Validates target version format (semver) +// 2. Determines installer type (EXE, MSI, PKG, or Homebrew) +// 3. Downloads installer file from GitHub releases (if applicable) +// 4. Verifies installer signature using reposign package (cryptographic verification in service process before +// launching updater) +// 5. Copies service binary to tempDir as "updater" (or "updater.exe" on Windows) +// 6. Launches updater process with detached mode: +// - --temp-dir: Temporary directory path +// - --service-dir: Service installation directory +// - --installer-file: Path to downloaded installer (if applicable) +// - --dry-run: Optional flag to test without actually installing +// 7. Service process continues running (will be terminated by installer later) +// 8. Service can watch for result.json using ResultHandler.Watch() to detect completion +// +// Updater Process (Setup): +// +// 1. Receives parameters from service via command-line arguments +// 2. Runs installer with appropriate silent/quiet flags: +// - Windows EXE: installer.exe /S +// - Windows MSI: msiexec.exe /i installer.msi /quiet /qn /l*v msi.log +// - macOS PKG: installer -pkg installer.pkg -target / +// - macOS Homebrew: brew upgrade netbirdio/tap/netbird +// 3. Installer terminates daemon and UI processes +// 4. Installer replaces binaries with new version +// 5. Updater waits for installer to complete +// 6. Updater restarts daemon: +// - Windows: netbird.exe service start +// - macOS/Linux: netbird service start +// 7. Updater restarts UI: +// - Windows: Launches netbird-ui.exe as active console user using CreateProcessAsUser +// - macOS: Uses launchctl asuser to launch NetBird.app for console user +// - Linux: Not implemented (UI typically auto-starts) +// 8. Updater writes result.json with success/error status +// 9. Updater process exits +// +// # Result Communication +// +// The ResultHandler (result.go) manages communication between updater and service: +// +// Result Structure: +// +// type Result struct { +// Success bool // true if installation succeeded +// Error string // error message if Success is false +// ExecutedAt time.Time // when installation completed +// } +// +// Result files are automatically cleaned up after being read. +// +// # File Locations +// +// Temporary Directory (platform-specific): +// +// Windows: +// - Path: %ProgramData%\Netbird\tmp-install +// - Example: C:\ProgramData\Netbird\tmp-install +// +// macOS: +// - Path: /var/lib/netbird/tmp-install +// - Requires root permissions +// +// Files created during installation: +// +// tmp-install/ +// installer.log +// updater[.exe] # Copy of service binary +// netbird_installer_*.[exe|msi|pkg] # Downloaded installer +// result.json # Installation result +// msi.log # MSI verbose log (Windows MSI only) +// +// # API Reference +// +// # Cleanup +// +// CleanUpInstallerFiles() removes temporary files after successful installation: +// - Downloaded installer files (*.exe, *.msi, *.pkg) +// - Updater binary copy +// - Does NOT remove result.json (cleaned by ResultHandler after read) +// - Does NOT remove msi.log (kept for debugging) +// +// # Dry-Run Mode +// +// Dry-run mode allows testing the update process without actually installing: +// +// Enable via environment variable: +// +// export NB_AUTO_UPDATE_DRY_RUN=true +// netbird service install-update 0.29.0 +package installer diff --git a/client/internal/updatemanager/installer/installer.go b/client/internal/updatemanager/installer/installer.go new file mode 100644 index 000000000..caf5873f8 --- /dev/null +++ b/client/internal/updatemanager/installer/installer.go @@ -0,0 +1,50 @@ +//go:build !windows && !darwin + +package installer + +import ( + "context" + "fmt" +) + +const ( + updaterBinary = "updater" +) + +type Installer struct { + tempDir string +} + +// New used by the service +func New() *Installer { + return &Installer{} +} + +// NewWithDir used by the updater process, get the tempDir from the service via cmd line +func NewWithDir(tempDir string) *Installer { + return &Installer{ + tempDir: tempDir, + } +} + +func (u *Installer) TempDir() string { + return "" +} + +func (c *Installer) LogFiles() []string { + return []string{} +} + +func (u *Installer) CleanUpInstallerFiles() error { + return nil +} + +func (u *Installer) RunInstallation(ctx context.Context, targetVersion string) error { + return fmt.Errorf("unsupported platform") +} + +// Setup runs the installer with appropriate arguments and manages the daemon/UI state +// This will be run by the updater process +func (u *Installer) Setup(ctx context.Context, dryRun bool, targetVersion string, daemonFolder string) (resultErr error) { + return fmt.Errorf("unsupported platform") +} diff --git a/client/internal/updatemanager/installer/installer_common.go b/client/internal/updatemanager/installer/installer_common.go new file mode 100644 index 000000000..03378d55f --- /dev/null +++ b/client/internal/updatemanager/installer/installer_common.go @@ -0,0 +1,293 @@ +//go:build windows || darwin + +package installer + +import ( + "context" + "fmt" + "io" + "os" + "os/exec" + "path" + "path/filepath" + "strings" + + "github.com/hashicorp/go-multierror" + goversion "github.com/hashicorp/go-version" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/updatemanager/downloader" + "github.com/netbirdio/netbird/client/internal/updatemanager/reposign" +) + +type Installer struct { + tempDir string +} + +// New used by the service +func New() *Installer { + return &Installer{ + tempDir: defaultTempDir, + } +} + +// NewWithDir used by the updater process, get the tempDir from the service via cmd line +func NewWithDir(tempDir string) *Installer { + return &Installer{ + tempDir: tempDir, + } +} + +// RunInstallation starts the updater process to run the installation +// This will run by the original service process +func (u *Installer) RunInstallation(ctx context.Context, targetVersion string) (err error) { + resultHandler := NewResultHandler(u.tempDir) + + defer func() { + if err != nil { + if writeErr := resultHandler.WriteErr(err); writeErr != nil { + log.Errorf("failed to write error result: %v", writeErr) + } + } + }() + + if err := validateTargetVersion(targetVersion); err != nil { + return err + } + + if err := u.mkTempDir(); err != nil { + return err + } + + var installerFile string + // Download files only when not using any third-party store + if installerType := TypeOfInstaller(ctx); installerType.Downloadable() { + log.Infof("download installer") + var err error + installerFile, err = u.downloadInstaller(ctx, installerType, targetVersion) + if err != nil { + log.Errorf("failed to download installer: %v", err) + return err + } + + artifactVerify, err := reposign.NewArtifactVerify(DefaultSigningKeysBaseURL) + if err != nil { + log.Errorf("failed to create artifact verify: %v", err) + return err + } + + if err := artifactVerify.Verify(ctx, targetVersion, installerFile); err != nil { + log.Errorf("artifact verification error: %v", err) + return err + } + } + + log.Infof("running installer") + updaterPath, err := u.copyUpdater() + if err != nil { + return err + } + + // the directory where the service has been installed + workspace, err := getServiceDir() + if err != nil { + return err + } + + args := []string{ + "--temp-dir", u.tempDir, + "--service-dir", workspace, + } + + if isDryRunEnabled() { + args = append(args, "--dry-run=true") + } + + if installerFile != "" { + args = append(args, "--installer-file", installerFile) + } + + updateCmd := exec.Command(updaterPath, args...) + log.Infof("starting updater process: %s", updateCmd.String()) + + // Configure the updater to run in a separate session/process group + // so it survives the parent daemon being stopped + setUpdaterProcAttr(updateCmd) + + // Start the updater process asynchronously + if err := updateCmd.Start(); err != nil { + return err + } + + pid := updateCmd.Process.Pid + log.Infof("updater started with PID %d", pid) + + // Release the process so the OS can fully detach it + if err := updateCmd.Process.Release(); err != nil { + log.Warnf("failed to release updater process: %v", err) + } + + return nil +} + +// CleanUpInstallerFiles +// - the installer file (pkg, exe, msi) +// - the selfcopy updater.exe +func (u *Installer) CleanUpInstallerFiles() error { + // Check if tempDir exists + info, err := os.Stat(u.tempDir) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + + if !info.IsDir() { + return nil + } + + var merr *multierror.Error + + if err := os.Remove(filepath.Join(u.tempDir, updaterBinary)); err != nil && !os.IsNotExist(err) { + merr = multierror.Append(merr, fmt.Errorf("failed to remove updater binary: %w", err)) + } + + entries, err := os.ReadDir(u.tempDir) + if err != nil { + return err + } + + for _, entry := range entries { + if entry.IsDir() { + continue + } + + name := entry.Name() + for _, ext := range binaryExtensions { + if strings.HasSuffix(strings.ToLower(name), strings.ToLower(ext)) { + if err := os.Remove(filepath.Join(u.tempDir, name)); err != nil { + merr = multierror.Append(merr, fmt.Errorf("failed to remove %s: %w", name, err)) + } + break + } + } + } + + return merr.ErrorOrNil() +} + +func (u *Installer) downloadInstaller(ctx context.Context, installerType Type, targetVersion string) (string, error) { + fileURL := urlWithVersionArch(installerType, targetVersion) + + // Clean up temp directory on error + var success bool + defer func() { + if !success { + if err := os.RemoveAll(u.tempDir); err != nil { + log.Errorf("error cleaning up temporary directory: %v", err) + } + } + }() + + fileName := path.Base(fileURL) + if fileName == "." || fileName == "/" || fileName == "" { + return "", fmt.Errorf("invalid file URL: %s", fileURL) + } + + outputFilePath := filepath.Join(u.tempDir, fileName) + if err := downloader.DownloadToFile(ctx, downloader.DefaultRetryDelay, fileURL, outputFilePath); err != nil { + return "", err + } + + success = true + return outputFilePath, nil +} + +func (u *Installer) TempDir() string { + return u.tempDir +} + +func (u *Installer) mkTempDir() error { + if err := os.MkdirAll(u.tempDir, 0o755); err != nil { + log.Debugf("failed to create tempdir: %s", u.tempDir) + return err + } + return nil +} + +func (u *Installer) copyUpdater() (string, error) { + src, err := getServiceBinary() + if err != nil { + return "", fmt.Errorf("failed to get updater binary: %w", err) + } + + dst := filepath.Join(u.tempDir, updaterBinary) + if err := copyFile(src, dst); err != nil { + return "", fmt.Errorf("failed to copy updater binary: %w", err) + } + + if err := os.Chmod(dst, 0o755); err != nil { + return "", fmt.Errorf("failed to set permissions: %w", err) + } + + return dst, nil +} + +func validateTargetVersion(targetVersion string) error { + if targetVersion == "" { + return fmt.Errorf("target version cannot be empty") + } + + _, err := goversion.NewVersion(targetVersion) + if err != nil { + return fmt.Errorf("invalid target version %q: %w", targetVersion, err) + } + + return nil +} + +func copyFile(src, dst string) error { + log.Infof("copying %s to %s", src, dst) + in, err := os.Open(src) + if err != nil { + return fmt.Errorf("open source: %w", err) + } + defer func() { + if err := in.Close(); err != nil { + log.Warnf("failed to close source file: %v", err) + } + }() + + out, err := os.Create(dst) + if err != nil { + return fmt.Errorf("create destination: %w", err) + } + defer func() { + if err := out.Close(); err != nil { + log.Warnf("failed to close destination file: %v", err) + } + }() + + if _, err := io.Copy(out, in); err != nil { + return fmt.Errorf("copy: %w", err) + } + + return nil +} + +func getServiceDir() (string, error) { + exePath, err := os.Executable() + if err != nil { + return "", err + } + return filepath.Dir(exePath), nil +} + +func getServiceBinary() (string, error) { + return os.Executable() +} + +func isDryRunEnabled() bool { + return strings.EqualFold(strings.TrimSpace(os.Getenv("NB_AUTO_UPDATE_DRY_RUN")), "true") +} diff --git a/client/internal/updatemanager/installer/installer_log_darwin.go b/client/internal/updatemanager/installer/installer_log_darwin.go new file mode 100644 index 000000000..50dd5d197 --- /dev/null +++ b/client/internal/updatemanager/installer/installer_log_darwin.go @@ -0,0 +1,11 @@ +package installer + +import ( + "path/filepath" +) + +func (u *Installer) LogFiles() []string { + return []string{ + filepath.Join(u.tempDir, LogFile), + } +} diff --git a/client/internal/updatemanager/installer/installer_log_windows.go b/client/internal/updatemanager/installer/installer_log_windows.go new file mode 100644 index 000000000..96e4cfd1f --- /dev/null +++ b/client/internal/updatemanager/installer/installer_log_windows.go @@ -0,0 +1,12 @@ +package installer + +import ( + "path/filepath" +) + +func (u *Installer) LogFiles() []string { + return []string{ + filepath.Join(u.tempDir, msiLogFile), + filepath.Join(u.tempDir, LogFile), + } +} diff --git a/client/internal/updatemanager/installer/installer_run_darwin.go b/client/internal/updatemanager/installer/installer_run_darwin.go new file mode 100644 index 000000000..248a404aa --- /dev/null +++ b/client/internal/updatemanager/installer/installer_run_darwin.go @@ -0,0 +1,238 @@ +package installer + +import ( + "context" + "fmt" + "os" + "os/exec" + "os/user" + "path/filepath" + "runtime" + "strings" + "syscall" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + daemonName = "netbird" + updaterBinary = "updater" + uiBinary = "/Applications/NetBird.app" + + defaultTempDir = "/var/lib/netbird/tmp-install" + + pkgDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_%version_darwin_%arch.pkg" +) + +var ( + binaryExtensions = []string{"pkg"} +) + +// Setup runs the installer with appropriate arguments and manages the daemon/UI state +// This will be run by the updater process +func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) { + resultHandler := NewResultHandler(u.tempDir) + + // Always ensure daemon and UI are restarted after setup + defer func() { + log.Infof("write out result") + var err error + if resultErr == nil { + err = resultHandler.WriteSuccess() + } else { + err = resultHandler.WriteErr(resultErr) + } + if err != nil { + log.Errorf("failed to write update result: %v", err) + } + + // skip service restart if dry-run mode is enabled + if dryRun { + return + } + + log.Infof("starting daemon back") + if err := u.startDaemon(daemonFolder); err != nil { + log.Errorf("failed to start daemon: %v", err) + } + + log.Infof("starting UI back") + if err := u.startUIAsUser(); err != nil { + log.Errorf("failed to start UI: %v", err) + } + + }() + + if dryRun { + time.Sleep(7 * time.Second) + log.Infof("dry-run mode enabled, skipping actual installation") + resultErr = fmt.Errorf("dry-run mode enabled") + return + } + + switch TypeOfInstaller(ctx) { + case TypePKG: + resultErr = u.installPkgFile(ctx, installerFile) + case TypeHomebrew: + resultErr = u.updateHomeBrew(ctx) + } + + return resultErr +} + +func (u *Installer) startDaemon(daemonFolder string) error { + log.Infof("starting netbird service") + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, filepath.Join(daemonFolder, daemonName), "service", "start") + if output, err := cmd.CombinedOutput(); err != nil { + log.Warnf("failed to start netbird service: %v, output: %s", err, string(output)) + return err + } + log.Infof("netbird service started successfully") + return nil +} + +func (u *Installer) startUIAsUser() error { + log.Infof("starting netbird-ui: %s", uiBinary) + + // Get the current console user + cmd := exec.Command("stat", "-f", "%Su", "/dev/console") + output, err := cmd.Output() + if err != nil { + return fmt.Errorf("failed to get console user: %w", err) + } + + username := strings.TrimSpace(string(output)) + if username == "" || username == "root" { + return fmt.Errorf("no active user session found") + } + + log.Infof("starting UI for user: %s", username) + + // Get user's UID + userInfo, err := user.Lookup(username) + if err != nil { + return fmt.Errorf("failed to lookup user %s: %w", username, err) + } + + // Start the UI process as the console user using launchctl + // This ensures the app runs in the user's context with proper GUI access + launchCmd := exec.Command("launchctl", "asuser", userInfo.Uid, "open", "-a", uiBinary) + log.Infof("launchCmd: %s", launchCmd.String()) + // Set the user's home directory for proper macOS app behavior + launchCmd.Env = append(os.Environ(), "HOME="+userInfo.HomeDir) + log.Infof("set HOME environment variable: %s", userInfo.HomeDir) + + if err := launchCmd.Start(); err != nil { + return fmt.Errorf("failed to start UI process: %w", err) + } + + // Release the process so it can run independently + if err := launchCmd.Process.Release(); err != nil { + log.Warnf("failed to release UI process: %v", err) + } + + log.Infof("netbird-ui started successfully for user %s", username) + return nil +} + +func (u *Installer) installPkgFile(ctx context.Context, path string) error { + log.Infof("installing pkg file: %s", path) + + // Kill any existing UI processes before installation + // This ensures the postinstall script's "open $APP" will start the new version + u.killUI() + + volume := "/" + + cmd := exec.CommandContext(ctx, "installer", "-pkg", path, "-target", volume) + if err := cmd.Start(); err != nil { + return fmt.Errorf("error running pkg file: %w", err) + } + log.Infof("installer started with PID %d", cmd.Process.Pid) + if err := cmd.Wait(); err != nil { + return fmt.Errorf("error running pkg file: %w", err) + } + log.Infof("pkg file installed successfully") + return nil +} + +func (u *Installer) updateHomeBrew(ctx context.Context) error { + log.Infof("updating homebrew") + + // Kill any existing UI processes before upgrade + // This ensures the new version will be started after upgrade + u.killUI() + + // Homebrew must be run as a non-root user + // To find out which user installed NetBird using HomeBrew we can check the owner of our brew tap directory + // Check both Apple Silicon and Intel Mac paths + brewTapPath := "/opt/homebrew/Library/Taps/netbirdio/homebrew-tap/" + brewBinPath := "/opt/homebrew/bin/brew" + if _, err := os.Stat(brewTapPath); os.IsNotExist(err) { + // Try Intel Mac path + brewTapPath = "/usr/local/Homebrew/Library/Taps/netbirdio/homebrew-tap/" + brewBinPath = "/usr/local/bin/brew" + } + + fileInfo, err := os.Stat(brewTapPath) + if err != nil { + return fmt.Errorf("error getting homebrew installation path info: %w", err) + } + + fileSysInfo, ok := fileInfo.Sys().(*syscall.Stat_t) + if !ok { + return fmt.Errorf("error checking file owner, sysInfo type is %T not *syscall.Stat_t", fileInfo.Sys()) + } + + // Get username from UID + brewUser, err := user.LookupId(fmt.Sprintf("%d", fileSysInfo.Uid)) + if err != nil { + return fmt.Errorf("error looking up brew installer user: %w", err) + } + userName := brewUser.Username + // Get user HOME, required for brew to run correctly + // https://github.com/Homebrew/brew/issues/15833 + homeDir := brewUser.HomeDir + + // Check if netbird-ui is installed (must run as the brew user, not root) + checkUICmd := exec.CommandContext(ctx, "sudo", "-u", userName, brewBinPath, "list", "--formula", "netbirdio/tap/netbird-ui") + checkUICmd.Env = append(os.Environ(), "HOME="+homeDir) + uiInstalled := checkUICmd.Run() == nil + + // Homebrew does not support installing specific versions + // Thus it will always update to latest and ignore targetVersion + upgradeArgs := []string{"-u", userName, brewBinPath, "upgrade", "netbirdio/tap/netbird"} + if uiInstalled { + upgradeArgs = append(upgradeArgs, "netbirdio/tap/netbird-ui") + } + + cmd := exec.CommandContext(ctx, "sudo", upgradeArgs...) + cmd.Env = append(os.Environ(), "HOME="+homeDir) + + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("error running brew upgrade: %w, output: %s", err, string(output)) + } + + log.Infof("homebrew updated successfully") + return nil +} + +func (u *Installer) killUI() { + log.Infof("killing existing netbird-ui processes") + cmd := exec.Command("pkill", "-x", "netbird-ui") + if output, err := cmd.CombinedOutput(); err != nil { + // pkill returns exit code 1 if no processes matched, which is fine + log.Debugf("pkill netbird-ui result: %v, output: %s", err, string(output)) + } else { + log.Infof("netbird-ui processes killed") + } +} + +func urlWithVersionArch(_ Type, version string) string { + url := strings.ReplaceAll(pkgDownloadURL, "%version", version) + return strings.ReplaceAll(url, "%arch", runtime.GOARCH) +} diff --git a/client/internal/updatemanager/installer/installer_run_windows.go b/client/internal/updatemanager/installer/installer_run_windows.go new file mode 100644 index 000000000..70c7e32cf --- /dev/null +++ b/client/internal/updatemanager/installer/installer_run_windows.go @@ -0,0 +1,213 @@ +package installer + +import ( + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "time" + "unsafe" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" +) + +const ( + daemonName = "netbird.exe" + uiName = "netbird-ui.exe" + updaterBinary = "updater.exe" + + msiLogFile = "msi.log" + + msiDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.msi" + exeDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.exe" +) + +var ( + defaultTempDir = filepath.Join(os.Getenv("ProgramData"), "Netbird", "tmp-install") + + // for the cleanup + binaryExtensions = []string{"msi", "exe"} +) + +// Setup runs the installer with appropriate arguments and manages the daemon/UI state +// This will be run by the updater process +func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) { + resultHandler := NewResultHandler(u.tempDir) + + // Always ensure daemon and UI are restarted after setup + defer func() { + log.Infof("starting daemon back") + if err := u.startDaemon(daemonFolder); err != nil { + log.Errorf("failed to start daemon: %v", err) + } + + log.Infof("starting UI back") + if err := u.startUIAsUser(daemonFolder); err != nil { + log.Errorf("failed to start UI: %v", err) + } + + log.Infof("write out result") + var err error + if resultErr == nil { + err = resultHandler.WriteSuccess() + } else { + err = resultHandler.WriteErr(resultErr) + } + if err != nil { + log.Errorf("failed to write update result: %v", err) + } + }() + + if dryRun { + log.Infof("dry-run mode enabled, skipping actual installation") + resultErr = fmt.Errorf("dry-run mode enabled") + return + } + + installerType, err := typeByFileExtension(installerFile) + if err != nil { + log.Debugf("%v", err) + resultErr = err + return + } + + var cmd *exec.Cmd + switch installerType { + case TypeExe: + log.Infof("run exe installer: %s", installerFile) + cmd = exec.CommandContext(ctx, installerFile, "/S") + default: + installerDir := filepath.Dir(installerFile) + logPath := filepath.Join(installerDir, msiLogFile) + log.Infof("run msi installer: %s", installerFile) + cmd = exec.CommandContext(ctx, "msiexec.exe", "/i", filepath.Base(installerFile), "/quiet", "/qn", "/l*v", logPath) + } + + cmd.Dir = filepath.Dir(installerFile) + + if resultErr = cmd.Start(); resultErr != nil { + log.Errorf("error starting installer: %v", resultErr) + return + } + + log.Infof("installer started with PID %d", cmd.Process.Pid) + if resultErr = cmd.Wait(); resultErr != nil { + log.Errorf("installer process finished with error: %v", resultErr) + return + } + + return nil +} + +func (u *Installer) startDaemon(daemonFolder string) error { + log.Infof("starting netbird service") + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, filepath.Join(daemonFolder, daemonName), "service", "start") + if output, err := cmd.CombinedOutput(); err != nil { + log.Debugf("failed to start netbird service: %v, output: %s", err, string(output)) + return err + } + log.Infof("netbird service started successfully") + return nil +} + +func (u *Installer) startUIAsUser(daemonFolder string) error { + uiPath := filepath.Join(daemonFolder, uiName) + log.Infof("starting netbird-ui: %s", uiPath) + + // Get the active console session ID + sessionID := windows.WTSGetActiveConsoleSessionId() + if sessionID == 0xFFFFFFFF { + return fmt.Errorf("no active user session found") + } + + // Get the user token for that session + var userToken windows.Token + err := windows.WTSQueryUserToken(sessionID, &userToken) + if err != nil { + return fmt.Errorf("failed to query user token: %w", err) + } + defer func() { + if err := userToken.Close(); err != nil { + log.Warnf("failed to close user token: %v", err) + } + }() + + // Duplicate the token to a primary token + var primaryToken windows.Token + err = windows.DuplicateTokenEx( + userToken, + windows.MAXIMUM_ALLOWED, + nil, + windows.SecurityImpersonation, + windows.TokenPrimary, + &primaryToken, + ) + if err != nil { + return fmt.Errorf("failed to duplicate token: %w", err) + } + defer func() { + if err := primaryToken.Close(); err != nil { + log.Warnf("failed to close token: %v", err) + } + }() + + // Prepare startup info + var si windows.StartupInfo + si.Cb = uint32(unsafe.Sizeof(si)) + si.Desktop = windows.StringToUTF16Ptr("winsta0\\default") + + var pi windows.ProcessInformation + + cmdLine, err := windows.UTF16PtrFromString(fmt.Sprintf("\"%s\"", uiPath)) + if err != nil { + return fmt.Errorf("failed to convert path to UTF16: %w", err) + } + + creationFlags := uint32(0x00000200 | 0x00000008 | 0x00000400) // CREATE_NEW_PROCESS_GROUP | DETACHED_PROCESS | CREATE_UNICODE_ENVIRONMENT + + err = windows.CreateProcessAsUser( + primaryToken, + nil, + cmdLine, + nil, + nil, + false, + creationFlags, + nil, + nil, + &si, + &pi, + ) + if err != nil { + return fmt.Errorf("CreateProcessAsUser failed: %w", err) + } + + // Close handles + if err := windows.CloseHandle(pi.Process); err != nil { + log.Warnf("failed to close process handle: %v", err) + } + if err := windows.CloseHandle(pi.Thread); err != nil { + log.Warnf("failed to close thread handle: %v", err) + } + + log.Infof("netbird-ui started successfully in session %d", sessionID) + return nil +} + +func urlWithVersionArch(it Type, version string) string { + var url string + if it == TypeExe { + url = exeDownloadURL + } else { + url = msiDownloadURL + } + url = strings.ReplaceAll(url, "%version", version) + return strings.ReplaceAll(url, "%arch", runtime.GOARCH) +} diff --git a/client/internal/updatemanager/installer/log.go b/client/internal/updatemanager/installer/log.go new file mode 100644 index 000000000..8b60dba28 --- /dev/null +++ b/client/internal/updatemanager/installer/log.go @@ -0,0 +1,5 @@ +package installer + +const ( + LogFile = "installer.log" +) diff --git a/client/internal/updatemanager/installer/procattr_darwin.go b/client/internal/updatemanager/installer/procattr_darwin.go new file mode 100644 index 000000000..56f2018bb --- /dev/null +++ b/client/internal/updatemanager/installer/procattr_darwin.go @@ -0,0 +1,15 @@ +package installer + +import ( + "os/exec" + "syscall" +) + +// setUpdaterProcAttr configures the updater process to run in a new session, +// making it independent of the parent daemon process. This ensures the updater +// survives when the daemon is stopped during the pkg installation. +func setUpdaterProcAttr(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setsid: true, + } +} diff --git a/client/internal/updatemanager/installer/procattr_windows.go b/client/internal/updatemanager/installer/procattr_windows.go new file mode 100644 index 000000000..29a8a2de0 --- /dev/null +++ b/client/internal/updatemanager/installer/procattr_windows.go @@ -0,0 +1,14 @@ +package installer + +import ( + "os/exec" + "syscall" +) + +// setUpdaterProcAttr configures the updater process to run detached from the parent, +// making it independent of the parent daemon process. +func setUpdaterProcAttr(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP | 0x00000008, // 0x00000008 is DETACHED_PROCESS + } +} diff --git a/client/internal/updatemanager/installer/repourl_dev.go b/client/internal/updatemanager/installer/repourl_dev.go new file mode 100644 index 000000000..088821ad3 --- /dev/null +++ b/client/internal/updatemanager/installer/repourl_dev.go @@ -0,0 +1,7 @@ +//go:build devartifactsign + +package installer + +const ( + DefaultSigningKeysBaseURL = "http://192.168.0.10:9089/signrepo" +) diff --git a/client/internal/updatemanager/installer/repourl_prod.go b/client/internal/updatemanager/installer/repourl_prod.go new file mode 100644 index 000000000..abddc62c1 --- /dev/null +++ b/client/internal/updatemanager/installer/repourl_prod.go @@ -0,0 +1,7 @@ +//go:build !devartifactsign + +package installer + +const ( + DefaultSigningKeysBaseURL = "https://publickeys.netbird.io/artifact-signatures" +) diff --git a/client/internal/updatemanager/installer/result.go b/client/internal/updatemanager/installer/result.go new file mode 100644 index 000000000..03d08d527 --- /dev/null +++ b/client/internal/updatemanager/installer/result.go @@ -0,0 +1,230 @@ +package installer + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/fsnotify/fsnotify" + log "github.com/sirupsen/logrus" +) + +const ( + resultFile = "result.json" +) + +type Result struct { + Success bool + Error string + ExecutedAt time.Time +} + +// ResultHandler handles reading and writing update results +type ResultHandler struct { + resultFile string +} + +// NewResultHandler creates a new communicator with the given directory path +// The result file will be created as "result.json" in the specified directory +func NewResultHandler(installerDir string) *ResultHandler { + // Create it if it doesn't exist + // do not care if already exists + _ = os.MkdirAll(installerDir, 0o700) + + rh := &ResultHandler{ + resultFile: filepath.Join(installerDir, resultFile), + } + return rh +} + +func (rh *ResultHandler) GetErrorResultReason() string { + result, err := rh.tryReadResult() + if err == nil && !result.Success { + return result.Error + } + + if err := rh.cleanup(); err != nil { + log.Warnf("failed to cleanup result file: %v", err) + } + + return "" +} + +func (rh *ResultHandler) WriteSuccess() error { + result := Result{ + Success: true, + ExecutedAt: time.Now(), + } + return rh.write(result) +} + +func (rh *ResultHandler) WriteErr(errReason error) error { + result := Result{ + Success: false, + Error: errReason.Error(), + ExecutedAt: time.Now(), + } + return rh.write(result) +} + +func (rh *ResultHandler) Watch(ctx context.Context) (Result, error) { + log.Infof("start watching result: %s", rh.resultFile) + + // Check if file already exists (updater finished before we started watching) + if result, err := rh.tryReadResult(); err == nil { + log.Infof("installer result: %v", result) + return result, nil + } + + dir := filepath.Dir(rh.resultFile) + + if err := rh.waitForDirectory(ctx, dir); err != nil { + return Result{}, err + } + + return rh.watchForResultFile(ctx, dir) +} + +func (rh *ResultHandler) waitForDirectory(ctx context.Context, dir string) error { + ticker := time.NewTicker(300 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + if info, err := os.Stat(dir); err == nil && info.IsDir() { + return nil + } + } + } +} + +func (rh *ResultHandler) watchForResultFile(ctx context.Context, dir string) (Result, error) { + watcher, err := fsnotify.NewWatcher() + if err != nil { + log.Error(err) + return Result{}, err + } + + defer func() { + if err := watcher.Close(); err != nil { + log.Warnf("failed to close watcher: %v", err) + } + }() + + if err := watcher.Add(dir); err != nil { + return Result{}, fmt.Errorf("failed to watch directory: %v", err) + } + + // Check again after setting up watcher to avoid race condition + // (file could have been created between initial check and watcher setup) + if result, err := rh.tryReadResult(); err == nil { + log.Infof("installer result: %v", result) + return result, nil + } + + for { + select { + case <-ctx.Done(): + return Result{}, ctx.Err() + case event, ok := <-watcher.Events: + if !ok { + return Result{}, errors.New("watcher closed unexpectedly") + } + + if result, done := rh.handleWatchEvent(event); done { + return result, nil + } + case err, ok := <-watcher.Errors: + if !ok { + return Result{}, errors.New("watcher closed unexpectedly") + } + return Result{}, fmt.Errorf("watcher error: %w", err) + } + } +} + +func (rh *ResultHandler) handleWatchEvent(event fsnotify.Event) (Result, bool) { + if event.Name != rh.resultFile { + return Result{}, false + } + + if event.Has(fsnotify.Create) { + result, err := rh.tryReadResult() + if err != nil { + log.Debugf("error while reading result: %v", err) + return result, true + } + log.Infof("installer result: %v", result) + return result, true + } + + return Result{}, false +} + +// Write writes the update result to a file for the UI to read +func (rh *ResultHandler) write(result Result) error { + log.Infof("write out installer result to: %s", rh.resultFile) + // Ensure directory exists + dir := filepath.Dir(rh.resultFile) + if err := os.MkdirAll(dir, 0o755); err != nil { + log.Errorf("failed to create directory %s: %v", dir, err) + return err + } + + data, err := json.Marshal(result) + if err != nil { + return err + } + + // Write to a temporary file first, then rename for atomic operation + tmpPath := rh.resultFile + ".tmp" + if err := os.WriteFile(tmpPath, data, 0o600); err != nil { + log.Errorf("failed to create temp file: %s", err) + return err + } + + // Atomic rename + if err := os.Rename(tmpPath, rh.resultFile); err != nil { + if cleanupErr := os.Remove(tmpPath); cleanupErr != nil { + log.Warnf("Failed to remove temp result file: %v", err) + } + return err + } + + return nil +} + +func (rh *ResultHandler) cleanup() error { + err := os.Remove(rh.resultFile) + if err != nil && !os.IsNotExist(err) { + return err + } + log.Debugf("delete installer result file: %s", rh.resultFile) + return nil +} + +// tryReadResult attempts to read and validate the result file +func (rh *ResultHandler) tryReadResult() (Result, error) { + data, err := os.ReadFile(rh.resultFile) + if err != nil { + return Result{}, err + } + + var result Result + if err := json.Unmarshal(data, &result); err != nil { + return Result{}, fmt.Errorf("invalid result format: %w", err) + } + + if err := rh.cleanup(); err != nil { + log.Warnf("failed to cleanup result file: %v", err) + } + + return result, nil +} diff --git a/client/internal/updatemanager/installer/types.go b/client/internal/updatemanager/installer/types.go new file mode 100644 index 000000000..656d84f88 --- /dev/null +++ b/client/internal/updatemanager/installer/types.go @@ -0,0 +1,14 @@ +package installer + +type Type struct { + name string + downloadable bool +} + +func (t Type) String() string { + return t.name +} + +func (t Type) Downloadable() bool { + return t.downloadable +} diff --git a/client/internal/updatemanager/installer/types_darwin.go b/client/internal/updatemanager/installer/types_darwin.go new file mode 100644 index 000000000..95a0cb737 --- /dev/null +++ b/client/internal/updatemanager/installer/types_darwin.go @@ -0,0 +1,22 @@ +package installer + +import ( + "context" + "os/exec" +) + +var ( + TypeHomebrew = Type{name: "Homebrew", downloadable: false} + TypePKG = Type{name: "pkg", downloadable: true} +) + +func TypeOfInstaller(ctx context.Context) Type { + cmd := exec.CommandContext(ctx, "pkgutil", "--pkg-info", "io.netbird.client") + _, err := cmd.Output() + if err != nil && cmd.ProcessState.ExitCode() == 1 { + // Not installed using pkg file, thus installed using Homebrew + + return TypeHomebrew + } + return TypePKG +} diff --git a/client/internal/updatemanager/installer/types_windows.go b/client/internal/updatemanager/installer/types_windows.go new file mode 100644 index 000000000..d4e5d83bd --- /dev/null +++ b/client/internal/updatemanager/installer/types_windows.go @@ -0,0 +1,51 @@ +package installer + +import ( + "context" + "fmt" + "strings" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows/registry" +) + +const ( + uninstallKeyPath64 = `SOFTWARE\WOW6432Node\Microsoft\Windows\CurrentVersion\Uninstall\Netbird` + uninstallKeyPath32 = `SOFTWARE\Microsoft\Windows\CurrentVersion\Uninstall\Netbird` +) + +var ( + TypeExe = Type{name: "EXE", downloadable: true} + TypeMSI = Type{name: "MSI", downloadable: true} +) + +func TypeOfInstaller(_ context.Context) Type { + paths := []string{uninstallKeyPath64, uninstallKeyPath32} + + for _, path := range paths { + k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE) + if err != nil { + continue + } + + if err := k.Close(); err != nil { + log.Warnf("Error closing registry key: %v", err) + } + return TypeExe + + } + + log.Debug("No registry entry found for Netbird, assuming MSI installation") + return TypeMSI +} + +func typeByFileExtension(filePath string) (Type, error) { + switch { + case strings.HasSuffix(strings.ToLower(filePath), ".exe"): + return TypeExe, nil + case strings.HasSuffix(strings.ToLower(filePath), ".msi"): + return TypeMSI, nil + default: + return Type{}, fmt.Errorf("unsupported installer type for file: %s", filePath) + } +} diff --git a/client/internal/updatemanager/manager.go b/client/internal/updatemanager/manager.go new file mode 100644 index 000000000..eae11de56 --- /dev/null +++ b/client/internal/updatemanager/manager.go @@ -0,0 +1,374 @@ +//go:build windows || darwin + +package updatemanager + +import ( + "context" + "errors" + "fmt" + "runtime" + "sync" + "time" + + v "github.com/hashicorp/go-version" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/internal/updatemanager/installer" + cProto "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/version" +) + +const ( + latestVersion = "latest" + // this version will be ignored + developmentVersion = "development" +) + +var errNoUpdateState = errors.New("no update state found") + +type UpdateState struct { + PreUpdateVersion string + TargetVersion string +} + +func (u UpdateState) Name() string { + return "autoUpdate" +} + +type Manager struct { + statusRecorder *peer.Status + stateManager *statemanager.Manager + + lastTrigger time.Time + mgmUpdateChan chan struct{} + updateChannel chan struct{} + currentVersion string + update UpdateInterface + wg sync.WaitGroup + + cancel context.CancelFunc + + expectedVersion *v.Version + updateToLatestVersion bool + + // updateMutex protect update and expectedVersion fields + updateMutex sync.Mutex + + triggerUpdateFn func(context.Context, string) error +} + +func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) { + if runtime.GOOS == "darwin" { + isBrew := !installer.TypeOfInstaller(context.Background()).Downloadable() + if isBrew { + log.Warnf("auto-update disabled on Home Brew installation") + return nil, fmt.Errorf("auto-update not supported on Home Brew installation yet") + } + } + return newManager(statusRecorder, stateManager) +} + +func newManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) { + manager := &Manager{ + statusRecorder: statusRecorder, + stateManager: stateManager, + mgmUpdateChan: make(chan struct{}, 1), + updateChannel: make(chan struct{}, 1), + currentVersion: version.NetbirdVersion(), + update: version.NewUpdate("nb/client"), + } + manager.triggerUpdateFn = manager.triggerUpdate + + stateManager.RegisterState(&UpdateState{}) + + return manager, nil +} + +// CheckUpdateSuccess checks if the update was successful and send a notification. +// It works without to start the update manager. +func (m *Manager) CheckUpdateSuccess(ctx context.Context) { + reason := m.lastResultErrReason() + if reason != "" { + m.statusRecorder.PublishEvent( + cProto.SystemEvent_ERROR, + cProto.SystemEvent_SYSTEM, + "Auto-update failed", + fmt.Sprintf("Auto-update failed: %s", reason), + nil, + ) + } + + updateState, err := m.loadAndDeleteUpdateState(ctx) + if err != nil { + if errors.Is(err, errNoUpdateState) { + return + } + log.Errorf("failed to load update state: %v", err) + return + } + + log.Debugf("auto-update state loaded, %v", *updateState) + + if updateState.TargetVersion == m.currentVersion { + m.statusRecorder.PublishEvent( + cProto.SystemEvent_INFO, + cProto.SystemEvent_SYSTEM, + "Auto-update completed", + fmt.Sprintf("Your NetBird Client was auto-updated to version %s", m.currentVersion), + nil, + ) + return + } +} + +func (m *Manager) Start(ctx context.Context) { + if m.cancel != nil { + log.Errorf("Manager already started") + return + } + + m.update.SetDaemonVersion(m.currentVersion) + m.update.SetOnUpdateListener(func() { + select { + case m.updateChannel <- struct{}{}: + default: + } + }) + go m.update.StartFetcher() + + ctx, cancel := context.WithCancel(ctx) + m.cancel = cancel + + m.wg.Add(1) + go m.updateLoop(ctx) +} + +func (m *Manager) SetVersion(expectedVersion string) { + log.Infof("set expected agent version for upgrade: %s", expectedVersion) + if m.cancel == nil { + log.Errorf("manager not started") + return + } + + m.updateMutex.Lock() + defer m.updateMutex.Unlock() + + if expectedVersion == "" { + log.Errorf("empty expected version provided") + m.expectedVersion = nil + m.updateToLatestVersion = false + return + } + + if expectedVersion == latestVersion { + m.updateToLatestVersion = true + m.expectedVersion = nil + } else { + expectedSemVer, err := v.NewVersion(expectedVersion) + if err != nil { + log.Errorf("error parsing version: %v", err) + return + } + if m.expectedVersion != nil && m.expectedVersion.Equal(expectedSemVer) { + return + } + m.expectedVersion = expectedSemVer + m.updateToLatestVersion = false + } + + select { + case m.mgmUpdateChan <- struct{}{}: + default: + } +} + +func (m *Manager) Stop() { + if m.cancel == nil { + return + } + + m.cancel() + m.updateMutex.Lock() + if m.update != nil { + m.update.StopWatch() + m.update = nil + } + m.updateMutex.Unlock() + + m.wg.Wait() +} + +func (m *Manager) onContextCancel() { + if m.cancel == nil { + return + } + + m.updateMutex.Lock() + defer m.updateMutex.Unlock() + if m.update != nil { + m.update.StopWatch() + m.update = nil + } +} + +func (m *Manager) updateLoop(ctx context.Context) { + defer m.wg.Done() + + for { + select { + case <-ctx.Done(): + m.onContextCancel() + return + case <-m.mgmUpdateChan: + case <-m.updateChannel: + log.Infof("fetched new version info") + } + + m.handleUpdate(ctx) + } +} + +func (m *Manager) handleUpdate(ctx context.Context) { + var updateVersion *v.Version + + m.updateMutex.Lock() + if m.update == nil { + m.updateMutex.Unlock() + return + } + + expectedVersion := m.expectedVersion + useLatest := m.updateToLatestVersion + curLatestVersion := m.update.LatestVersion() + m.updateMutex.Unlock() + + switch { + // Resolve "latest" to actual version + case useLatest: + if curLatestVersion == nil { + log.Tracef("latest version not fetched yet") + return + } + updateVersion = curLatestVersion + // Update to specific version + case expectedVersion != nil: + updateVersion = expectedVersion + default: + log.Debugf("no expected version information set") + return + } + + log.Debugf("checking update option, current version: %s, target version: %s", m.currentVersion, updateVersion) + if !m.shouldUpdate(updateVersion) { + return + } + + m.lastTrigger = time.Now() + log.Infof("Auto-update triggered, current version: %s, target version: %s", m.currentVersion, updateVersion) + m.statusRecorder.PublishEvent( + cProto.SystemEvent_CRITICAL, + cProto.SystemEvent_SYSTEM, + "Automatically updating client", + "Your client version is older than auto-update version set in Management, updating client now.", + nil, + ) + + m.statusRecorder.PublishEvent( + cProto.SystemEvent_CRITICAL, + cProto.SystemEvent_SYSTEM, + "", + "", + map[string]string{"progress_window": "show", "version": updateVersion.String()}, + ) + + updateState := UpdateState{ + PreUpdateVersion: m.currentVersion, + TargetVersion: updateVersion.String(), + } + + if err := m.stateManager.UpdateState(updateState); err != nil { + log.Warnf("failed to update state: %v", err) + } else { + if err = m.stateManager.PersistState(ctx); err != nil { + log.Warnf("failed to persist state: %v", err) + } + } + + if err := m.triggerUpdateFn(ctx, updateVersion.String()); err != nil { + log.Errorf("Error triggering auto-update: %v", err) + m.statusRecorder.PublishEvent( + cProto.SystemEvent_ERROR, + cProto.SystemEvent_SYSTEM, + "Auto-update failed", + fmt.Sprintf("Auto-update failed: %v", err), + nil, + ) + } +} + +// loadAndDeleteUpdateState loads the update state, deletes it from storage, and returns it. +// Returns nil if no state exists. +func (m *Manager) loadAndDeleteUpdateState(ctx context.Context) (*UpdateState, error) { + stateType := &UpdateState{} + + m.stateManager.RegisterState(stateType) + if err := m.stateManager.LoadState(stateType); err != nil { + return nil, fmt.Errorf("load state: %w", err) + } + + state := m.stateManager.GetState(stateType) + if state == nil { + return nil, errNoUpdateState + } + + updateState, ok := state.(*UpdateState) + if !ok { + return nil, fmt.Errorf("failed to cast state to UpdateState") + } + + if err := m.stateManager.DeleteState(updateState); err != nil { + return nil, fmt.Errorf("delete state: %w", err) + } + + if err := m.stateManager.PersistState(ctx); err != nil { + return nil, fmt.Errorf("persist state: %w", err) + } + + return updateState, nil +} + +func (m *Manager) shouldUpdate(updateVersion *v.Version) bool { + if m.currentVersion == developmentVersion { + log.Debugf("skipping auto-update, running development version") + return false + } + currentVersion, err := v.NewVersion(m.currentVersion) + if err != nil { + log.Errorf("error checking for update, error parsing version `%s`: %v", m.currentVersion, err) + return false + } + if currentVersion.GreaterThanOrEqual(updateVersion) { + log.Infof("current version (%s) is equal to or higher than auto-update version (%s)", m.currentVersion, updateVersion) + return false + } + + if time.Since(m.lastTrigger) < 5*time.Minute { + log.Debugf("skipping auto-update, last update was %s ago", time.Since(m.lastTrigger)) + return false + } + + return true +} + +func (m *Manager) lastResultErrReason() string { + inst := installer.New() + result := installer.NewResultHandler(inst.TempDir()) + return result.GetErrorResultReason() +} + +func (m *Manager) triggerUpdate(ctx context.Context, targetVersion string) error { + inst := installer.New() + return inst.RunInstallation(ctx, targetVersion) +} diff --git a/client/internal/updatemanager/manager_test.go b/client/internal/updatemanager/manager_test.go new file mode 100644 index 000000000..20ddec10d --- /dev/null +++ b/client/internal/updatemanager/manager_test.go @@ -0,0 +1,214 @@ +//go:build windows || darwin + +package updatemanager + +import ( + "context" + "fmt" + "path" + "testing" + "time" + + v "github.com/hashicorp/go-version" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +type versionUpdateMock struct { + latestVersion *v.Version + onUpdate func() +} + +func (v versionUpdateMock) StopWatch() {} + +func (v versionUpdateMock) SetDaemonVersion(newVersion string) bool { + return false +} + +func (v *versionUpdateMock) SetOnUpdateListener(updateFn func()) { + v.onUpdate = updateFn +} + +func (v versionUpdateMock) LatestVersion() *v.Version { + return v.latestVersion +} + +func (v versionUpdateMock) StartFetcher() {} + +func Test_LatestVersion(t *testing.T) { + testMatrix := []struct { + name string + daemonVersion string + initialLatestVersion *v.Version + latestVersion *v.Version + shouldUpdateInit bool + shouldUpdateLater bool + }{ + { + name: "Should only trigger update once due to time between triggers being < 5 Minutes", + daemonVersion: "1.0.0", + initialLatestVersion: v.Must(v.NewSemver("1.0.1")), + latestVersion: v.Must(v.NewSemver("1.0.2")), + shouldUpdateInit: true, + shouldUpdateLater: false, + }, + { + name: "Shouldn't update initially, but should update as soon as latest version is fetched", + daemonVersion: "1.0.0", + initialLatestVersion: nil, + latestVersion: v.Must(v.NewSemver("1.0.1")), + shouldUpdateInit: false, + shouldUpdateLater: true, + }, + } + + for idx, c := range testMatrix { + mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion} + tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx)) + m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile)) + m.update = mockUpdate + + targetVersionChan := make(chan string, 1) + + m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error { + targetVersionChan <- targetVersion + return nil + } + m.currentVersion = c.daemonVersion + m.Start(context.Background()) + m.SetVersion("latest") + var triggeredInit bool + select { + case targetVersion := <-targetVersionChan: + if targetVersion != c.initialLatestVersion.String() { + t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), targetVersion) + } + triggeredInit = true + case <-time.After(10 * time.Millisecond): + triggeredInit = false + } + if triggeredInit != c.shouldUpdateInit { + t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit) + } + + mockUpdate.latestVersion = c.latestVersion + mockUpdate.onUpdate() + + var triggeredLater bool + select { + case targetVersion := <-targetVersionChan: + if targetVersion != c.latestVersion.String() { + t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion) + } + triggeredLater = true + case <-time.After(10 * time.Millisecond): + triggeredLater = false + } + if triggeredLater != c.shouldUpdateLater { + t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater) + } + + m.Stop() + } +} + +func Test_HandleUpdate(t *testing.T) { + testMatrix := []struct { + name string + daemonVersion string + latestVersion *v.Version + expectedVersion string + shouldUpdate bool + }{ + { + name: "Update to a specific version should update regardless of if latestVersion is available yet", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "0.56.0", + shouldUpdate: true, + }, + { + name: "Update to specific version should not update if version matches", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "0.55.0", + shouldUpdate: false, + }, + { + name: "Update to specific version should not update if current version is newer", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "0.54.0", + shouldUpdate: false, + }, + { + name: "Update to latest version should update if latest is newer", + daemonVersion: "0.55.0", + latestVersion: v.Must(v.NewSemver("0.56.0")), + expectedVersion: "latest", + shouldUpdate: true, + }, + { + name: "Update to latest version should not update if latest == current", + daemonVersion: "0.56.0", + latestVersion: v.Must(v.NewSemver("0.56.0")), + expectedVersion: "latest", + shouldUpdate: false, + }, + { + name: "Should not update if daemon version is invalid", + daemonVersion: "development", + latestVersion: v.Must(v.NewSemver("1.0.0")), + expectedVersion: "latest", + shouldUpdate: false, + }, + { + name: "Should not update if expecting latest and latest version is unavailable", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "latest", + shouldUpdate: false, + }, + { + name: "Should not update if expected version is invalid", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "development", + shouldUpdate: false, + }, + } + for idx, c := range testMatrix { + tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx)) + m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile)) + m.update = &versionUpdateMock{latestVersion: c.latestVersion} + targetVersionChan := make(chan string, 1) + + m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error { + targetVersionChan <- targetVersion + return nil + } + + m.currentVersion = c.daemonVersion + m.Start(context.Background()) + m.SetVersion(c.expectedVersion) + + var updateTriggered bool + select { + case targetVersion := <-targetVersionChan: + if c.expectedVersion == "latest" && targetVersion != c.latestVersion.String() { + t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion) + } else if c.expectedVersion != "latest" && targetVersion != c.expectedVersion { + t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.expectedVersion, targetVersion) + } + updateTriggered = true + case <-time.After(10 * time.Millisecond): + updateTriggered = false + } + + if updateTriggered != c.shouldUpdate { + t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered) + } + m.Stop() + } +} diff --git a/client/internal/updatemanager/manager_unsupported.go b/client/internal/updatemanager/manager_unsupported.go new file mode 100644 index 000000000..4e87c2d77 --- /dev/null +++ b/client/internal/updatemanager/manager_unsupported.go @@ -0,0 +1,39 @@ +//go:build !windows && !darwin + +package updatemanager + +import ( + "context" + "fmt" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +// Manager is a no-op stub for unsupported platforms +type Manager struct{} + +// NewManager returns a no-op manager for unsupported platforms +func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) { + return nil, fmt.Errorf("update manager is not supported on this platform") +} + +// CheckUpdateSuccess is a no-op on unsupported platforms +func (m *Manager) CheckUpdateSuccess(ctx context.Context) { + // no-op +} + +// Start is a no-op on unsupported platforms +func (m *Manager) Start(ctx context.Context) { + // no-op +} + +// SetVersion is a no-op on unsupported platforms +func (m *Manager) SetVersion(expectedVersion string) { + // no-op +} + +// Stop is a no-op on unsupported platforms +func (m *Manager) Stop() { + // no-op +} diff --git a/client/internal/updatemanager/reposign/artifact.go b/client/internal/updatemanager/reposign/artifact.go new file mode 100644 index 000000000..3d4fe9c74 --- /dev/null +++ b/client/internal/updatemanager/reposign/artifact.go @@ -0,0 +1,302 @@ +package reposign + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/binary" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "hash" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/blake2s" +) + +const ( + tagArtifactPrivate = "ARTIFACT PRIVATE KEY" + tagArtifactPublic = "ARTIFACT PUBLIC KEY" + + maxArtifactKeySignatureAge = 10 * 365 * 24 * time.Hour + maxArtifactSignatureAge = 10 * 365 * 24 * time.Hour +) + +// ArtifactHash wraps a hash.Hash and counts bytes written +type ArtifactHash struct { + hash.Hash +} + +// NewArtifactHash returns an initialized ArtifactHash using BLAKE2s +func NewArtifactHash() *ArtifactHash { + h, err := blake2s.New256(nil) + if err != nil { + panic(err) // Should never happen with nil Key + } + return &ArtifactHash{Hash: h} +} + +func (ah *ArtifactHash) Write(b []byte) (int, error) { + return ah.Hash.Write(b) +} + +// ArtifactKey is a signing Key used to sign artifacts +type ArtifactKey struct { + PrivateKey +} + +func (k ArtifactKey) String() string { + return fmt.Sprintf( + "ArtifactKey[ID=%s, CreatedAt=%s, ExpiresAt=%s]", + k.Metadata.ID, + k.Metadata.CreatedAt.Format(time.RFC3339), + k.Metadata.ExpiresAt.Format(time.RFC3339), + ) +} + +func GenerateArtifactKey(rootKey *RootKey, expiration time.Duration) (*ArtifactKey, []byte, []byte, []byte, error) { + // Verify root key is still valid + if !rootKey.Metadata.ExpiresAt.IsZero() && time.Now().After(rootKey.Metadata.ExpiresAt) { + return nil, nil, nil, nil, fmt.Errorf("root key has expired on %s", rootKey.Metadata.ExpiresAt.Format(time.RFC3339)) + } + + now := time.Now() + expirationTime := now.Add(expiration) + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("generate ed25519 key: %w", err) + } + + metadata := KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: now.UTC(), + ExpiresAt: expirationTime.UTC(), + } + + ak := &ArtifactKey{ + PrivateKey{ + Key: priv, + Metadata: metadata, + }, + } + + // Marshal PrivateKey struct to JSON + privJSON, err := json.Marshal(ak.PrivateKey) + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("failed to marshal private key: %w", err) + } + + // Marshal PublicKey struct to JSON + pubKey := PublicKey{ + Key: pub, + Metadata: metadata, + } + pubJSON, err := json.Marshal(pubKey) + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("failed to marshal public key: %w", err) + } + + // Encode to PEM with metadata embedded in bytes + privPEM := pem.EncodeToMemory(&pem.Block{ + Type: tagArtifactPrivate, + Bytes: privJSON, + }) + + pubPEM := pem.EncodeToMemory(&pem.Block{ + Type: tagArtifactPublic, + Bytes: pubJSON, + }) + + // Sign the public key with the root key + signature, err := SignArtifactKey(*rootKey, pubPEM) + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("failed to sign artifact key: %w", err) + } + + return ak, privPEM, pubPEM, signature, nil +} + +func ParseArtifactKey(privKeyPEM []byte) (ArtifactKey, error) { + pk, err := parsePrivateKey(privKeyPEM, tagArtifactPrivate) + if err != nil { + return ArtifactKey{}, fmt.Errorf("failed to parse artifact Key: %w", err) + } + return ArtifactKey{pk}, nil +} + +func ParseArtifactPubKey(data []byte) (PublicKey, error) { + pk, _, err := parsePublicKey(data, tagArtifactPublic) + return pk, err +} + +func BundleArtifactKeys(rootKey *RootKey, keys []PublicKey) ([]byte, []byte, error) { + if len(keys) == 0 { + return nil, nil, errors.New("no keys to bundle") + } + + // Create bundle by concatenating PEM-encoded keys + var pubBundle []byte + + for _, pk := range keys { + // Marshal PublicKey struct to JSON + pubJSON, err := json.Marshal(pk) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal public key: %w", err) + } + + // Encode to PEM + pubPEM := pem.EncodeToMemory(&pem.Block{ + Type: tagArtifactPublic, + Bytes: pubJSON, + }) + + pubBundle = append(pubBundle, pubPEM...) + } + + // Sign the entire bundle with the root key + signature, err := SignArtifactKey(*rootKey, pubBundle) + if err != nil { + return nil, nil, fmt.Errorf("failed to sign artifact key bundle: %w", err) + } + + return pubBundle, signature, nil +} + +func ValidateArtifactKeys(publicRootKeys []PublicKey, data []byte, signature Signature, revocationList *RevocationList) ([]PublicKey, error) { + now := time.Now().UTC() + if signature.Timestamp.After(now.Add(maxClockSkew)) { + err := fmt.Errorf("signature timestamp is in the future: %v", signature.Timestamp) + log.Debugf("artifact signature error: %v", err) + return nil, err + } + if now.Sub(signature.Timestamp) > maxArtifactKeySignatureAge { + err := fmt.Errorf("signature is too old: %v (created %v)", now.Sub(signature.Timestamp), signature.Timestamp) + log.Debugf("artifact signature error: %v", err) + return nil, err + } + + // Reconstruct the signed message: artifact_key_data || timestamp + msg := make([]byte, 0, len(data)+8) + msg = append(msg, data...) + msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix())) + + if !verifyAny(publicRootKeys, msg, signature.Signature) { + return nil, errors.New("failed to verify signature of artifact keys") + } + + pubKeys, err := parsePublicKeyBundle(data, tagArtifactPublic) + if err != nil { + log.Debugf("failed to parse public keys: %s", err) + return nil, err + } + + validKeys := make([]PublicKey, 0, len(pubKeys)) + for _, pubKey := range pubKeys { + // Filter out expired keys + if !pubKey.Metadata.ExpiresAt.IsZero() && now.After(pubKey.Metadata.ExpiresAt) { + log.Debugf("Key %s is expired at %v (current time %v)", + pubKey.Metadata.ID, pubKey.Metadata.ExpiresAt, now) + continue + } + + if revocationList != nil { + if revTime, revoked := revocationList.Revoked[pubKey.Metadata.ID]; revoked { + log.Debugf("Key %s is revoked as of %v (created %v)", + pubKey.Metadata.ID, revTime, pubKey.Metadata.CreatedAt) + continue + } + } + validKeys = append(validKeys, pubKey) + } + + if len(validKeys) == 0 { + log.Debugf("no valid public keys found for artifact keys") + return nil, fmt.Errorf("all %d artifact keys are revoked", len(pubKeys)) + } + + return validKeys, nil +} + +func ValidateArtifact(artifactPubKeys []PublicKey, data []byte, signature Signature) error { + // Validate signature timestamp + now := time.Now().UTC() + if signature.Timestamp.After(now.Add(maxClockSkew)) { + err := fmt.Errorf("artifact signature timestamp is in the future: %v", signature.Timestamp) + log.Debugf("failed to verify signature of artifact: %s", err) + return err + } + if now.Sub(signature.Timestamp) > maxArtifactSignatureAge { + return fmt.Errorf("artifact signature is too old: %v (created %v)", + now.Sub(signature.Timestamp), signature.Timestamp) + } + + h := NewArtifactHash() + if _, err := h.Write(data); err != nil { + return fmt.Errorf("failed to hash artifact: %w", err) + } + hash := h.Sum(nil) + + // Reconstruct the signed message: hash || length || timestamp + msg := make([]byte, 0, len(hash)+8+8) + msg = append(msg, hash...) + msg = binary.LittleEndian.AppendUint64(msg, uint64(len(data))) + msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix())) + + // Find matching Key and verify + for _, keyInfo := range artifactPubKeys { + if keyInfo.Metadata.ID == signature.KeyID { + // Check Key expiration + if !keyInfo.Metadata.ExpiresAt.IsZero() && + signature.Timestamp.After(keyInfo.Metadata.ExpiresAt) { + return fmt.Errorf("signing Key %s expired at %v, signature from %v", + signature.KeyID, keyInfo.Metadata.ExpiresAt, signature.Timestamp) + } + + if ed25519.Verify(keyInfo.Key, msg, signature.Signature) { + log.Debugf("artifact verified successfully with Key: %s", signature.KeyID) + return nil + } + return fmt.Errorf("signature verification failed for Key %s", signature.KeyID) + } + } + + return fmt.Errorf("no signing Key found with ID %s", signature.KeyID) +} + +func SignData(artifactKey ArtifactKey, data []byte) ([]byte, error) { + if len(data) == 0 { // Check happens too late + return nil, fmt.Errorf("artifact length must be positive, got %d", len(data)) + } + + h := NewArtifactHash() + if _, err := h.Write(data); err != nil { + return nil, fmt.Errorf("failed to write artifact hash: %w", err) + } + + timestamp := time.Now().UTC() + + if !artifactKey.Metadata.ExpiresAt.IsZero() && timestamp.After(artifactKey.Metadata.ExpiresAt) { + return nil, fmt.Errorf("artifact key expired at %v", artifactKey.Metadata.ExpiresAt) + } + + hash := h.Sum(nil) + + // Create message: hash || length || timestamp + msg := make([]byte, 0, len(hash)+8+8) + msg = append(msg, hash...) + msg = binary.LittleEndian.AppendUint64(msg, uint64(len(data))) + msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix())) + + sig := ed25519.Sign(artifactKey.Key, msg) + + bundle := Signature{ + Signature: sig, + Timestamp: timestamp, + KeyID: artifactKey.Metadata.ID, + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + return json.Marshal(bundle) +} diff --git a/client/internal/updatemanager/reposign/artifact_test.go b/client/internal/updatemanager/reposign/artifact_test.go new file mode 100644 index 000000000..8865e2d0a --- /dev/null +++ b/client/internal/updatemanager/reposign/artifact_test.go @@ -0,0 +1,1080 @@ +package reposign + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/json" + "encoding/pem" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test ArtifactHash + +func TestNewArtifactHash(t *testing.T) { + h := NewArtifactHash() + assert.NotNil(t, h) + assert.NotNil(t, h.Hash) +} + +func TestArtifactHash_Write(t *testing.T) { + h := NewArtifactHash() + + data := []byte("test data") + n, err := h.Write(data) + require.NoError(t, err) + assert.Equal(t, len(data), n) + + hash := h.Sum(nil) + assert.NotEmpty(t, hash) + assert.Equal(t, 32, len(hash)) // BLAKE2s-256 +} + +func TestArtifactHash_Deterministic(t *testing.T) { + data := []byte("test data") + + h1 := NewArtifactHash() + if _, err := h1.Write(data); err != nil { + t.Fatal(err) + } + hash1 := h1.Sum(nil) + + h2 := NewArtifactHash() + if _, err := h2.Write(data); err != nil { + t.Fatal(err) + } + hash2 := h2.Sum(nil) + + assert.Equal(t, hash1, hash2) +} + +func TestArtifactHash_DifferentData(t *testing.T) { + h1 := NewArtifactHash() + if _, err := h1.Write([]byte("data1")); err != nil { + t.Fatal(err) + } + hash1 := h1.Sum(nil) + + h2 := NewArtifactHash() + if _, err := h2.Write([]byte("data2")); err != nil { + t.Fatal(err) + } + hash2 := h2.Sum(nil) + + assert.NotEqual(t, hash1, hash2) +} + +// Test ArtifactKey.String() + +func TestArtifactKey_String(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + createdAt := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + expiresAt := time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC) + + ak := ArtifactKey{ + PrivateKey{ + Key: priv, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: createdAt, + ExpiresAt: expiresAt, + }, + }, + } + + str := ak.String() + assert.Contains(t, str, "ArtifactKey") + assert.Contains(t, str, computeKeyID(pub).String()) + assert.Contains(t, str, "2024-01-15") + assert.Contains(t, str, "2025-01-15") +} + +// Test GenerateArtifactKey + +func TestGenerateArtifactKey_Valid(t *testing.T) { + // Create root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(), + }, + }, + } + + // Generate artifact key + ak, privPEM, pubPEM, signature, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + assert.NotNil(t, ak) + assert.NotEmpty(t, privPEM) + assert.NotEmpty(t, pubPEM) + assert.NotEmpty(t, signature) + + // Verify expiration + assert.True(t, ak.Metadata.ExpiresAt.After(time.Now())) + assert.True(t, ak.Metadata.ExpiresAt.Before(time.Now().Add(31*24*time.Hour))) +} + +func TestGenerateArtifactKey_ExpiredRoot(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + // Create expired root key + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().Add(-2 * 365 * 24 * time.Hour).UTC(), + ExpiresAt: time.Now().Add(-1 * time.Hour).UTC(), // Expired + }, + }, + } + + _, _, _, _, err = GenerateArtifactKey(rootKey, 30*24*time.Hour) + assert.Error(t, err) + assert.Contains(t, err.Error(), "expired") +} + +func TestGenerateArtifactKey_NoExpiration(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + // Root key with no expiration + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Time{}, // No expiration + }, + }, + } + + ak, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + assert.NotNil(t, ak) +} + +// Test ParseArtifactKey + +func TestParseArtifactKey_Valid(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + original, privPEM, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + // Parse it back + parsed, err := ParseArtifactKey(privPEM) + require.NoError(t, err) + + assert.Equal(t, original.Key, parsed.Key) + assert.Equal(t, original.Metadata.ID, parsed.Metadata.ID) +} + +func TestParseArtifactKey_InvalidPEM(t *testing.T) { + _, err := ParseArtifactKey([]byte("invalid pem")) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse") +} + +func TestParseArtifactKey_WrongType(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + // Create a root key (wrong type) + rootKey := RootKey{ + PrivateKey{ + Key: priv, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + privJSON, err := json.Marshal(rootKey.PrivateKey) + require.NoError(t, err) + + privPEM := encodePrivateKey(privJSON, tagRootPrivate) + + _, err = ParseArtifactKey(privPEM) + assert.Error(t, err) +} + +// Test ParseArtifactPubKey + +func TestParseArtifactPubKey_Valid(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + original, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + parsed, err := ParseArtifactPubKey(pubPEM) + require.NoError(t, err) + + assert.Equal(t, original.Metadata.ID, parsed.Metadata.ID) +} + +func TestParseArtifactPubKey_Invalid(t *testing.T) { + _, err := ParseArtifactPubKey([]byte("invalid")) + assert.Error(t, err) +} + +// Test BundleArtifactKeys + +func TestBundleArtifactKeys_Single(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + _, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + pubKey, err := ParseArtifactPubKey(pubPEM) + require.NoError(t, err) + + bundle, signature, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey}) + require.NoError(t, err) + assert.NotEmpty(t, bundle) + assert.NotEmpty(t, signature) +} + +func TestBundleArtifactKeys_Multiple(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate 3 artifact keys + var pubKeys []PublicKey + for i := 0; i < 3; i++ { + _, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + pubKey, err := ParseArtifactPubKey(pubPEM) + require.NoError(t, err) + pubKeys = append(pubKeys, pubKey) + } + + bundle, signature, err := BundleArtifactKeys(rootKey, pubKeys) + require.NoError(t, err) + assert.NotEmpty(t, bundle) + assert.NotEmpty(t, signature) + + // Verify we can parse the bundle + parsed, err := parsePublicKeyBundle(bundle, tagArtifactPublic) + require.NoError(t, err) + assert.Len(t, parsed, 3) +} + +func TestBundleArtifactKeys_Empty(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + _, _, err = BundleArtifactKeys(rootKey, []PublicKey{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no keys") +} + +// Test ValidateArtifactKeys + +func TestSingleValidateArtifactKey_Valid(t *testing.T) { + // Create root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate artifact key + _, _, pubPEM, sigData, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + sig, _ := ParseSignature(sigData) + + // Validate + validKeys, err := ValidateArtifactKeys(rootKeys, pubPEM, *sig, nil) + require.NoError(t, err) + assert.Len(t, validKeys, 1) +} + +func TestValidateArtifactKeys_Valid(t *testing.T) { + // Create root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate artifact key + _, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + pubKey, err := ParseArtifactPubKey(pubPEM) + require.NoError(t, err) + + // Bundle and sign + bundle, sigData, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey}) + require.NoError(t, err) + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // Validate + validKeys, err := ValidateArtifactKeys(rootKeys, bundle, *sig, nil) + require.NoError(t, err) + assert.Len(t, validKeys, 1) +} + +func TestValidateArtifactKeys_FutureTimestamp(t *testing.T) { + rootPub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + sig := Signature{ + Signature: make([]byte, 64), + Timestamp: time.Now().UTC().Add(10 * time.Minute), + KeyID: computeKeyID(rootPub), + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + _, err = ValidateArtifactKeys(rootKeys, []byte("data"), sig, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "in the future") +} + +func TestValidateArtifactKeys_TooOld(t *testing.T) { + rootPub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + sig := Signature{ + Signature: make([]byte, 64), + Timestamp: time.Now().UTC().Add(-20 * 365 * 24 * time.Hour), + KeyID: computeKeyID(rootPub), + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + _, err = ValidateArtifactKeys(rootKeys, []byte("data"), sig, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "too old") +} + +func TestValidateArtifactKeys_InvalidSignature(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + _, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + pubKey, err := ParseArtifactPubKey(pubPEM) + require.NoError(t, err) + + bundle, _, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey}) + require.NoError(t, err) + + // Create invalid signature + invalidSig := Signature{ + Signature: make([]byte, 64), + Timestamp: time.Now().UTC(), + KeyID: computeKeyID(rootPub), + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + _, err = ValidateArtifactKeys(rootKeys, bundle, invalidSig, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to verify") +} + +func TestValidateArtifactKeys_WithRevocation(t *testing.T) { + // Create root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate two artifact keys + _, _, pubPEM1, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + pubKey1, err := ParseArtifactPubKey(pubPEM1) + require.NoError(t, err) + + _, _, pubPEM2, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + pubKey2, err := ParseArtifactPubKey(pubPEM2) + require.NoError(t, err) + + // Bundle both keys + bundle, sigData, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey1, pubKey2}) + require.NoError(t, err) + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // Create revocation list with first key revoked + revocationList := &RevocationList{ + Revoked: map[KeyID]time.Time{ + pubKey1.Metadata.ID: time.Now().UTC(), + }, + LastUpdated: time.Now().UTC(), + } + + // Validate - should only return second key + validKeys, err := ValidateArtifactKeys(rootKeys, bundle, *sig, revocationList) + require.NoError(t, err) + assert.Len(t, validKeys, 1) + assert.Equal(t, pubKey2.Metadata.ID, validKeys[0].Metadata.ID) +} + +func TestValidateArtifactKeys_AllRevoked(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + _, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + pubKey, err := ParseArtifactPubKey(pubPEM) + require.NoError(t, err) + + bundle, sigData, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey}) + require.NoError(t, err) + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // Revoke the key + revocationList := &RevocationList{ + Revoked: map[KeyID]time.Time{ + pubKey.Metadata.ID: time.Now().UTC(), + }, + LastUpdated: time.Now().UTC(), + } + + _, err = ValidateArtifactKeys(rootKeys, bundle, *sig, revocationList) + assert.Error(t, err) + assert.Contains(t, err.Error(), "revoked") +} + +// Test ValidateArtifact + +func TestValidateArtifact_Valid(t *testing.T) { + // Create root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate artifact key + artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + // Sign some data + data := []byte("test artifact data") + sigData, err := SignData(*artifactKey, data) + require.NoError(t, err) + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // Get public key for validation + artifactPubKey := PublicKey{ + Key: artifactKey.Key.Public().(ed25519.PublicKey), + Metadata: artifactKey.Metadata, + } + + // Validate + err = ValidateArtifact([]PublicKey{artifactPubKey}, data, *sig) + assert.NoError(t, err) +} + +func TestValidateArtifact_FutureTimestamp(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + artifactPubKey := PublicKey{ + Key: pub, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + sig := Signature{ + Signature: make([]byte, 64), + Timestamp: time.Now().UTC().Add(10 * time.Minute), + KeyID: computeKeyID(pub), + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + err = ValidateArtifact([]PublicKey{artifactPubKey}, []byte("data"), sig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "in the future") +} + +func TestValidateArtifact_TooOld(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + artifactPubKey := PublicKey{ + Key: pub, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + sig := Signature{ + Signature: make([]byte, 64), + Timestamp: time.Now().UTC().Add(-20 * 365 * 24 * time.Hour), + KeyID: computeKeyID(pub), + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + err = ValidateArtifact([]PublicKey{artifactPubKey}, []byte("data"), sig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "too old") +} + +func TestValidateArtifact_ExpiredKey(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate artifact key with very short expiration + artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 1*time.Millisecond) + require.NoError(t, err) + + // Wait for key to expire + time.Sleep(10 * time.Millisecond) + + // Try to sign - should succeed but with old timestamp + data := []byte("test data") + sigData, err := SignData(*artifactKey, data) + require.Error(t, err) // Key is expired, so signing should fail + assert.Contains(t, err.Error(), "expired") + assert.Nil(t, sigData) +} + +func TestValidateArtifact_WrongKey(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate two artifact keys + artifactKey1, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + artifactKey2, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + // Sign with key1 + data := []byte("test data") + sigData, err := SignData(*artifactKey1, data) + require.NoError(t, err) + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // Try to validate with key2 only + artifactPubKey2 := PublicKey{ + Key: artifactKey2.Key.Public().(ed25519.PublicKey), + Metadata: artifactKey2.Metadata, + } + + err = ValidateArtifact([]PublicKey{artifactPubKey2}, data, *sig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no signing Key found") +} + +func TestValidateArtifact_TamperedData(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + // Sign original data + originalData := []byte("original data") + sigData, err := SignData(*artifactKey, originalData) + require.NoError(t, err) + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + artifactPubKey := PublicKey{ + Key: artifactKey.Key.Public().(ed25519.PublicKey), + Metadata: artifactKey.Metadata, + } + + // Try to validate with tampered data + tamperedData := []byte("tampered data") + err = ValidateArtifact([]PublicKey{artifactPubKey}, tamperedData, *sig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "verification failed") +} + +func TestValidateArtifactKeys_TwoKeysOneExpired(t *testing.T) { + // Test ValidateArtifactKeys with a bundle containing two keys where one is expired + // Should return only the valid key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate first key with very short expiration + _, _, expiredPubPEM, _, err := GenerateArtifactKey(rootKey, 1*time.Millisecond) + require.NoError(t, err) + expiredPubKey, err := ParseArtifactPubKey(expiredPubPEM) + require.NoError(t, err) + + // Wait for first key to expire + time.Sleep(10 * time.Millisecond) + + // Generate second key with normal expiration + _, _, validPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + validPubKey, err := ParseArtifactPubKey(validPubPEM) + require.NoError(t, err) + + // Bundle both keys together + bundle, sigData, err := BundleArtifactKeys(rootKey, []PublicKey{expiredPubKey, validPubKey}) + require.NoError(t, err) + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // ValidateArtifactKeys should return only the valid key + validKeys, err := ValidateArtifactKeys(rootKeys, bundle, *sig, nil) + require.NoError(t, err) + assert.Len(t, validKeys, 1) + assert.Equal(t, validPubKey.Metadata.ID, validKeys[0].Metadata.ID) +} + +func TestValidateArtifactKeys_TwoKeysBothExpired(t *testing.T) { + // Test ValidateArtifactKeys with a bundle containing two expired keys + // Should fail because no valid keys remain + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate first key with + _, _, pubPEM1, _, err := GenerateArtifactKey(rootKey, 24*time.Hour) + require.NoError(t, err) + pubKey1, err := ParseArtifactPubKey(pubPEM1) + require.NoError(t, err) + + // Generate second key with very short expiration + _, _, pubPEM2, _, err := GenerateArtifactKey(rootKey, 1*time.Millisecond) + require.NoError(t, err) + pubKey2, err := ParseArtifactPubKey(pubPEM2) + require.NoError(t, err) + + // Wait for expire + time.Sleep(10 * time.Millisecond) + + bundle, sigData, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey1, pubKey2}) + require.NoError(t, err) + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // ValidateArtifactKeys should fail because all keys are expired + keys, err := ValidateArtifactKeys(rootKeys, bundle, *sig, nil) + assert.NoError(t, err) + assert.Len(t, keys, 1) +} + +// Test SignData + +func TestSignData_Valid(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + data := []byte("test data to sign") + sigData, err := SignData(*artifactKey, data) + require.NoError(t, err) + assert.NotEmpty(t, sigData) + + // Verify signature can be parsed + sig, err := ParseSignature(sigData) + require.NoError(t, err) + assert.NotEmpty(t, sig.Signature) + assert.Equal(t, "ed25519", sig.Algorithm) + assert.Equal(t, "blake2s", sig.HashAlgo) +} + +func TestSignData_EmptyData(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + _, err = SignData(*artifactKey, []byte{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "must be positive") +} + +func TestSignData_ExpiredKey(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Generate key with very short expiration + artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 1*time.Millisecond) + require.NoError(t, err) + + // Wait for expiration + time.Sleep(10 * time.Millisecond) + + // Try to sign with expired key + _, err = SignData(*artifactKey, []byte("data")) + assert.Error(t, err) + assert.Contains(t, err.Error(), "expired") +} + +// Integration test + +func TestArtifact_FullWorkflow(t *testing.T) { + // Step 1: Create root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := &RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Step 2: Generate artifact key + artifactKey, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + // Step 3: Create and validate key bundle + artifactPubKey, err := ParseArtifactPubKey(pubPEM) + require.NoError(t, err) + + bundle, bundleSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey}) + require.NoError(t, err) + + sig, err := ParseSignature(bundleSig) + require.NoError(t, err) + + validKeys, err := ValidateArtifactKeys(rootKeys, bundle, *sig, nil) + require.NoError(t, err) + assert.Len(t, validKeys, 1) + + // Step 4: Sign artifact data + artifactData := []byte("This is my artifact data that needs to be signed") + artifactSig, err := SignData(*artifactKey, artifactData) + require.NoError(t, err) + + // Step 5: Validate artifact + parsedSig, err := ParseSignature(artifactSig) + require.NoError(t, err) + + err = ValidateArtifact(validKeys, artifactData, *parsedSig) + assert.NoError(t, err) +} + +// Helper function for tests +func encodePrivateKey(jsonData []byte, typeTag string) []byte { + return pem.EncodeToMemory(&pem.Block{ + Type: typeTag, + Bytes: jsonData, + }) +} diff --git a/client/internal/updatemanager/reposign/certs/root-pub.pem b/client/internal/updatemanager/reposign/certs/root-pub.pem new file mode 100644 index 000000000..e7c2fd2c0 --- /dev/null +++ b/client/internal/updatemanager/reposign/certs/root-pub.pem @@ -0,0 +1,6 @@ +-----BEGIN ROOT PUBLIC KEY----- +eyJLZXkiOiJoaGIxdGRDSEZNMFBuQWp1b2w2cXJ1QXRFbWFFSlg1QjFsZUNxWmpn +V1pvPSIsIk1ldGFkYXRhIjp7ImlkIjoiOWE0OTg2NmI2MzE2MjNiNCIsImNyZWF0 +ZWRfYXQiOiIyMDI1LTExLTI0VDE3OjE1OjI4LjYyNzE3MzE3MVoiLCJleHBpcmVz +X2F0IjoiMjAzNS0xMS0yMlQxNzoxNToyOC42MjcxNzMxNzFaIn19 +-----END ROOT PUBLIC KEY----- diff --git a/client/internal/updatemanager/reposign/certsdev/root-pub.pem b/client/internal/updatemanager/reposign/certsdev/root-pub.pem new file mode 100644 index 000000000..f7145477b --- /dev/null +++ b/client/internal/updatemanager/reposign/certsdev/root-pub.pem @@ -0,0 +1,6 @@ +-----BEGIN ROOT PUBLIC KEY----- +eyJLZXkiOiJyTDByVTN2MEFOZUNmbDZraitiUUd3TE1waU5CaUJLdVBWSnZtQzgr +ZS84PSIsIk1ldGFkYXRhIjp7ImlkIjoiMTBkNjQyZTY2N2FmMDNkNCIsImNyZWF0 +ZWRfYXQiOiIyMDI1LTExLTIwVDE3OjI5OjI5LjE4MDk0NjMxNloiLCJleHBpcmVz +X2F0IjoiMjAyNi0xMS0yMFQxNzoyOToyOS4xODA5NDYzMTZaIn19 +-----END ROOT PUBLIC KEY----- diff --git a/client/internal/updatemanager/reposign/doc.go b/client/internal/updatemanager/reposign/doc.go new file mode 100644 index 000000000..660b9d11d --- /dev/null +++ b/client/internal/updatemanager/reposign/doc.go @@ -0,0 +1,174 @@ +// Package reposign implements a cryptographic signing and verification system +// for NetBird software update artifacts. It provides a hierarchical key +// management system with support for key rotation, revocation, and secure +// artifact distribution. +// +// # Architecture +// +// The package uses a two-tier key hierarchy: +// +// - Root Keys: Long-lived keys that sign artifact keys. These are embedded +// in the client binary and establish the root of trust. Root keys should +// be kept offline and highly secured. +// +// - Artifact Keys: Short-lived keys that sign release artifacts (binaries, +// packages, etc.). These are rotated regularly and can be revoked if +// compromised. Artifact keys are signed by root keys and distributed via +// a public repository. +// +// This separation allows for operational flexibility: artifact keys can be +// rotated frequently without requiring client updates, while root keys remain +// stable and embedded in the software. +// +// # Cryptographic Primitives +// +// The package uses strong, modern cryptographic algorithms: +// - Ed25519: Fast, secure digital signatures (no timing attacks) +// - BLAKE2s-256: Fast cryptographic hash for artifacts +// - SHA-256: Key ID generation +// - JSON: Structured key and signature serialization +// - PEM: Standard key encoding format +// +// # Security Features +// +// Timestamp Binding: +// - All signatures include cryptographically-bound timestamps +// - Prevents replay attacks and enforces signature freshness +// - Clock skew tolerance: 5 minutes +// +// Key Expiration: +// - All keys have expiration times +// - Expired keys are automatically rejected +// - Signing with an expired key fails immediately +// +// Key Revocation: +// - Compromised keys can be revoked via a signed revocation list +// - Revocation list is checked during artifact validation +// - Revoked keys are filtered out before artifact verification +// +// # File Structure +// +// The package expects the following file layout in the key repository: +// +// signrepo/ +// artifact-key-pub.pem # Bundle of artifact public keys +// artifact-key-pub.pem.sig # Root signature of the bundle +// revocation-list.json # List of revoked key IDs +// revocation-list.json.sig # Root signature of revocation list +// +// And in the artifacts repository: +// +// releases/ +// v0.28.0/ +// netbird-linux-amd64 +// netbird-linux-amd64.sig # Artifact signature +// netbird-darwin-amd64 +// netbird-darwin-amd64.sig +// ... +// +// # Embedded Root Keys +// +// Root public keys are embedded in the client binary at compile time: +// - Production keys: certs/ directory +// - Development keys: certsdev/ directory +// +// The build tag determines which keys are embedded: +// - Production builds: //go:build !devartifactsign +// - Development builds: //go:build devartifactsign +// +// This ensures that development artifacts cannot be verified using production +// keys and vice versa. +// +// # Key Rotation Strategies +// +// Root Key Rotation: +// +// Root keys can be rotated without breaking existing clients by leveraging +// the multi-key verification system. The loadEmbeddedPublicKeys function +// reads ALL files from the certs/ directory and accepts signatures from ANY +// of the embedded root keys. +// +// To rotate root keys: +// +// 1. Generate a new root key pair: +// newRootKey, privPEM, pubPEM, err := GenerateRootKey(10 * 365 * 24 * time.Hour) +// +// 2. Add the new public key to the certs/ directory as a new file: +// certs/ +// root-pub-2024.pem # Old key (keep this!) +// root-pub-2025.pem # New key (add this) +// +// 3. Build new client versions with both keys embedded. The verification +// will accept signatures from either key. +// +// 4. Start signing new artifact keys with the new root key. Old clients +// with only the old root key will reject these, but new clients with +// both keys will accept them. +// +// Each file in certs/ can contain a single key or a bundle of keys (multiple +// PEM blocks). The system will parse all keys from all files and use them +// for verification. This provides maximum flexibility for key management. +// +// Important: Never remove all old root keys at once. Always maintain at least +// one overlapping key between releases to ensure smooth transitions. +// +// Artifact Key Rotation: +// +// Artifact keys should be rotated regularly (e.g., every 90 days) using the +// bundling mechanism. The BundleArtifactKeys function allows multiple artifact +// keys to be bundled together in a single signed package, and ValidateArtifact +// will accept signatures from ANY key in the bundle. +// +// To rotate artifact keys smoothly: +// +// 1. Generate a new artifact key while keeping the old one: +// newKey, newPrivPEM, newPubPEM, newSig, err := GenerateArtifactKey(rootKey, 90 * 24 * time.Hour) +// // Keep oldPubPEM and oldKey available +// +// 2. Create a bundle containing both old and new public keys +// +// 3. Upload the bundle and its signature to the key repository: +// signrepo/artifact-key-pub.pem # Contains both keys +// signrepo/artifact-key-pub.pem.sig # Root signature +// +// 4. Start signing new releases with the NEW key, but keep the bundle +// unchanged. Clients will download the bundle (containing both keys) +// and accept signatures from either key. +// +// Key bundle validation workflow: +// 1. Client downloads artifact-key-pub.pem and artifact-key-pub.pem.sig +// 2. ValidateArtifactKeys verifies the bundle signature with ANY embedded root key +// 3. ValidateArtifactKeys parses all public keys from the bundle +// 4. ValidateArtifactKeys filters out expired or revoked keys +// 5. When verifying an artifact, ValidateArtifact tries each key until one succeeds +// +// This multi-key acceptance model enables overlapping validity periods and +// smooth transitions without client update requirements. +// +// # Best Practices +// +// Root Key Management: +// - Generate root keys offline on an air-gapped machine +// - Store root private keys in hardware security modules (HSM) if possible +// - Use separate root keys for production and development +// - Rotate root keys infrequently (e.g., every 5-10 years) +// - Plan for root key rotation: embed multiple root public keys +// +// Artifact Key Management: +// - Rotate artifact keys regularly (e.g., every 90 days) +// - Use separate artifact keys for different release channels if needed +// - Revoke keys immediately upon suspected compromise +// - Bundle multiple artifact keys to enable smooth rotation +// +// Signing Process: +// - Sign artifacts in a secure CI/CD environment +// - Never commit private keys to version control +// - Use environment variables or secret management for keys +// - Verify signatures immediately after signing +// +// Distribution: +// - Serve keys and revocation lists from a reliable CDN +// - Use HTTPS for all key and artifact downloads +// - Monitor download failures and signature verification failures +// - Keep revocation list up to date +package reposign diff --git a/client/internal/updatemanager/reposign/embed_dev.go b/client/internal/updatemanager/reposign/embed_dev.go new file mode 100644 index 000000000..ef8f77373 --- /dev/null +++ b/client/internal/updatemanager/reposign/embed_dev.go @@ -0,0 +1,10 @@ +//go:build devartifactsign + +package reposign + +import "embed" + +//go:embed certsdev +var embeddedCerts embed.FS + +const embeddedCertsDir = "certsdev" diff --git a/client/internal/updatemanager/reposign/embed_prod.go b/client/internal/updatemanager/reposign/embed_prod.go new file mode 100644 index 000000000..91530e5f4 --- /dev/null +++ b/client/internal/updatemanager/reposign/embed_prod.go @@ -0,0 +1,10 @@ +//go:build !devartifactsign + +package reposign + +import "embed" + +//go:embed certs +var embeddedCerts embed.FS + +const embeddedCertsDir = "certs" diff --git a/client/internal/updatemanager/reposign/key.go b/client/internal/updatemanager/reposign/key.go new file mode 100644 index 000000000..bedfef70d --- /dev/null +++ b/client/internal/updatemanager/reposign/key.go @@ -0,0 +1,171 @@ +package reposign + +import ( + "crypto/ed25519" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "time" +) + +const ( + maxClockSkew = 5 * time.Minute +) + +// KeyID is a unique identifier for a Key (first 8 bytes of SHA-256 of public Key) +type KeyID [8]byte + +// computeKeyID generates a unique ID from a public Key +func computeKeyID(pub ed25519.PublicKey) KeyID { + h := sha256.Sum256(pub) + var id KeyID + copy(id[:], h[:8]) + return id +} + +// MarshalJSON implements json.Marshaler for KeyID +func (k KeyID) MarshalJSON() ([]byte, error) { + return json.Marshal(k.String()) +} + +// UnmarshalJSON implements json.Unmarshaler for KeyID +func (k *KeyID) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + + parsed, err := ParseKeyID(s) + if err != nil { + return err + } + + *k = parsed + return nil +} + +// ParseKeyID parses a hex string (16 hex chars = 8 bytes) into a KeyID. +func ParseKeyID(s string) (KeyID, error) { + var id KeyID + if len(s) != 16 { + return id, fmt.Errorf("invalid KeyID length: got %d, want 16 hex chars (8 bytes)", len(s)) + } + + b, err := hex.DecodeString(s) + if err != nil { + return id, fmt.Errorf("failed to decode KeyID: %w", err) + } + + copy(id[:], b) + return id, nil +} + +func (k KeyID) String() string { + return fmt.Sprintf("%x", k[:]) +} + +// KeyMetadata contains versioning and lifecycle information for a Key +type KeyMetadata struct { + ID KeyID `json:"id"` + CreatedAt time.Time `json:"created_at"` + ExpiresAt time.Time `json:"expires_at,omitempty"` // Optional expiration +} + +// PublicKey wraps a public Key with its Metadata +type PublicKey struct { + Key ed25519.PublicKey + Metadata KeyMetadata +} + +func parsePublicKeyBundle(bundle []byte, typeTag string) ([]PublicKey, error) { + var keys []PublicKey + for len(bundle) > 0 { + keyInfo, rest, err := parsePublicKey(bundle, typeTag) + if err != nil { + return nil, err + } + keys = append(keys, keyInfo) + bundle = rest + } + if len(keys) == 0 { + return nil, errors.New("no keys found in bundle") + } + return keys, nil +} + +func parsePublicKey(data []byte, typeTag string) (PublicKey, []byte, error) { + b, rest := pem.Decode(data) + if b == nil { + return PublicKey{}, nil, errors.New("failed to decode PEM data") + } + if b.Type != typeTag { + return PublicKey{}, nil, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag) + } + + // Unmarshal JSON-embedded format + var pub PublicKey + if err := json.Unmarshal(b.Bytes, &pub); err != nil { + return PublicKey{}, nil, fmt.Errorf("failed to unmarshal public key: %w", err) + } + + // Validate key length + if len(pub.Key) != ed25519.PublicKeySize { + return PublicKey{}, nil, fmt.Errorf("incorrect Ed25519 public key size: expected %d, got %d", + ed25519.PublicKeySize, len(pub.Key)) + } + + // Always recompute ID to ensure integrity + pub.Metadata.ID = computeKeyID(pub.Key) + + return pub, rest, nil +} + +type PrivateKey struct { + Key ed25519.PrivateKey + Metadata KeyMetadata +} + +func parsePrivateKey(data []byte, typeTag string) (PrivateKey, error) { + b, rest := pem.Decode(data) + if b == nil { + return PrivateKey{}, errors.New("failed to decode PEM data") + } + if len(rest) > 0 { + return PrivateKey{}, errors.New("trailing PEM data") + } + if b.Type != typeTag { + return PrivateKey{}, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag) + } + + // Unmarshal JSON-embedded format + var pk PrivateKey + if err := json.Unmarshal(b.Bytes, &pk); err != nil { + return PrivateKey{}, fmt.Errorf("failed to unmarshal private key: %w", err) + } + + // Validate key length + if len(pk.Key) != ed25519.PrivateKeySize { + return PrivateKey{}, fmt.Errorf("incorrect Ed25519 private key size: expected %d, got %d", + ed25519.PrivateKeySize, len(pk.Key)) + } + + return pk, nil +} + +func verifyAny(publicRootKeys []PublicKey, msg, sig []byte) bool { + // Verify with root keys + var rootKeys []ed25519.PublicKey + for _, r := range publicRootKeys { + rootKeys = append(rootKeys, r.Key) + } + + for _, k := range rootKeys { + if ed25519.Verify(k, msg, sig) { + return true + } + } + return false +} diff --git a/client/internal/updatemanager/reposign/key_test.go b/client/internal/updatemanager/reposign/key_test.go new file mode 100644 index 000000000..f8e1676fb --- /dev/null +++ b/client/internal/updatemanager/reposign/key_test.go @@ -0,0 +1,636 @@ +package reposign + +import ( + "crypto/ed25519" + "crypto/rand" + "crypto/sha256" + "encoding/json" + "encoding/pem" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test KeyID functions + +func TestComputeKeyID(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + keyID := computeKeyID(pub) + + // Verify it's the first 8 bytes of SHA-256 + h := sha256.Sum256(pub) + expectedID := KeyID{} + copy(expectedID[:], h[:8]) + + assert.Equal(t, expectedID, keyID) +} + +func TestComputeKeyID_Deterministic(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + // Computing KeyID multiple times should give the same result + keyID1 := computeKeyID(pub) + keyID2 := computeKeyID(pub) + + assert.Equal(t, keyID1, keyID2) +} + +func TestComputeKeyID_DifferentKeys(t *testing.T) { + pub1, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pub2, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + keyID1 := computeKeyID(pub1) + keyID2 := computeKeyID(pub2) + + // Different keys should produce different IDs + assert.NotEqual(t, keyID1, keyID2) +} + +func TestParseKeyID_Valid(t *testing.T) { + hexStr := "0123456789abcdef" + + keyID, err := ParseKeyID(hexStr) + require.NoError(t, err) + + expected := KeyID{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef} + assert.Equal(t, expected, keyID) +} + +func TestParseKeyID_InvalidLength(t *testing.T) { + tests := []struct { + name string + input string + }{ + {"too short", "01234567"}, + {"too long", "0123456789abcdef00"}, + {"empty", ""}, + {"odd length", "0123456789abcde"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := ParseKeyID(tt.input) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid KeyID length") + }) + } +} + +func TestParseKeyID_InvalidHex(t *testing.T) { + invalidHex := "0123456789abcxyz" // 'xyz' are not valid hex + + _, err := ParseKeyID(invalidHex) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to decode KeyID") +} + +func TestKeyID_String(t *testing.T) { + keyID := KeyID{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef} + + str := keyID.String() + assert.Equal(t, "0123456789abcdef", str) +} + +func TestKeyID_RoundTrip(t *testing.T) { + original := "fedcba9876543210" + + keyID, err := ParseKeyID(original) + require.NoError(t, err) + + result := keyID.String() + assert.Equal(t, original, result) +} + +func TestKeyID_ZeroValue(t *testing.T) { + keyID := KeyID{} + str := keyID.String() + assert.Equal(t, "0000000000000000", str) +} + +// Test KeyMetadata + +func TestKeyMetadata_JSONMarshaling(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + metadata := KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC), + ExpiresAt: time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC), + } + + jsonData, err := json.Marshal(metadata) + require.NoError(t, err) + + var decoded KeyMetadata + err = json.Unmarshal(jsonData, &decoded) + require.NoError(t, err) + + assert.Equal(t, metadata.ID, decoded.ID) + assert.Equal(t, metadata.CreatedAt.Unix(), decoded.CreatedAt.Unix()) + assert.Equal(t, metadata.ExpiresAt.Unix(), decoded.ExpiresAt.Unix()) +} + +func TestKeyMetadata_NoExpiration(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + metadata := KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC), + ExpiresAt: time.Time{}, // Zero value = no expiration + } + + jsonData, err := json.Marshal(metadata) + require.NoError(t, err) + + var decoded KeyMetadata + err = json.Unmarshal(jsonData, &decoded) + require.NoError(t, err) + + assert.True(t, decoded.ExpiresAt.IsZero()) +} + +// Test PublicKey + +func TestPublicKey_JSONMarshaling(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pubKey := PublicKey{ + Key: pub, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(), + }, + } + + jsonData, err := json.Marshal(pubKey) + require.NoError(t, err) + + var decoded PublicKey + err = json.Unmarshal(jsonData, &decoded) + require.NoError(t, err) + + assert.Equal(t, pubKey.Key, decoded.Key) + assert.Equal(t, pubKey.Metadata.ID, decoded.Metadata.ID) +} + +// Test parsePublicKey + +func TestParsePublicKey_Valid(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + metadata := KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(), + } + + pubKey := PublicKey{ + Key: pub, + Metadata: metadata, + } + + // Marshal to JSON + jsonData, err := json.Marshal(pubKey) + require.NoError(t, err) + + // Encode to PEM + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPublic, + Bytes: jsonData, + }) + + // Parse it back + parsed, rest, err := parsePublicKey(pemData, tagRootPublic) + require.NoError(t, err) + assert.Empty(t, rest) + assert.Equal(t, pub, parsed.Key) + assert.Equal(t, metadata.ID, parsed.Metadata.ID) +} + +func TestParsePublicKey_InvalidPEM(t *testing.T) { + invalidPEM := []byte("not a PEM") + + _, _, err := parsePublicKey(invalidPEM, tagRootPublic) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to decode PEM") +} + +func TestParsePublicKey_WrongType(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pubKey := PublicKey{ + Key: pub, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(pubKey) + require.NoError(t, err) + + // Encode with wrong type + pemData := pem.EncodeToMemory(&pem.Block{ + Type: "WRONG TYPE", + Bytes: jsonData, + }) + + _, _, err = parsePublicKey(pemData, tagRootPublic) + assert.Error(t, err) + assert.Contains(t, err.Error(), "PEM type") +} + +func TestParsePublicKey_InvalidJSON(t *testing.T) { + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPublic, + Bytes: []byte("invalid json"), + }) + + _, _, err := parsePublicKey(pemData, tagRootPublic) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to unmarshal") +} + +func TestParsePublicKey_InvalidKeySize(t *testing.T) { + // Create a public key with wrong size + pubKey := PublicKey{ + Key: []byte{0x01, 0x02, 0x03}, // Too short + Metadata: KeyMetadata{ + ID: KeyID{}, + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(pubKey) + require.NoError(t, err) + + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPublic, + Bytes: jsonData, + }) + + _, _, err = parsePublicKey(pemData, tagRootPublic) + assert.Error(t, err) + assert.Contains(t, err.Error(), "incorrect Ed25519 public key size") +} + +func TestParsePublicKey_IDRecomputation(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + // Create a public key with WRONG ID + wrongID := KeyID{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff} + pubKey := PublicKey{ + Key: pub, + Metadata: KeyMetadata{ + ID: wrongID, + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(pubKey) + require.NoError(t, err) + + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPublic, + Bytes: jsonData, + }) + + // Parse should recompute the correct ID + parsed, _, err := parsePublicKey(pemData, tagRootPublic) + require.NoError(t, err) + + correctID := computeKeyID(pub) + assert.Equal(t, correctID, parsed.Metadata.ID) + assert.NotEqual(t, wrongID, parsed.Metadata.ID) +} + +// Test parsePublicKeyBundle + +func TestParsePublicKeyBundle_Single(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pubKey := PublicKey{ + Key: pub, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(pubKey) + require.NoError(t, err) + + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPublic, + Bytes: jsonData, + }) + + keys, err := parsePublicKeyBundle(pemData, tagRootPublic) + require.NoError(t, err) + assert.Len(t, keys, 1) + assert.Equal(t, pub, keys[0].Key) +} + +func TestParsePublicKeyBundle_Multiple(t *testing.T) { + var bundle []byte + + // Create 3 keys + for i := 0; i < 3; i++ { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pubKey := PublicKey{ + Key: pub, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(pubKey) + require.NoError(t, err) + + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPublic, + Bytes: jsonData, + }) + + bundle = append(bundle, pemData...) + } + + keys, err := parsePublicKeyBundle(bundle, tagRootPublic) + require.NoError(t, err) + assert.Len(t, keys, 3) +} + +func TestParsePublicKeyBundle_Empty(t *testing.T) { + _, err := parsePublicKeyBundle([]byte{}, tagRootPublic) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no keys found") +} + +func TestParsePublicKeyBundle_Invalid(t *testing.T) { + _, err := parsePublicKeyBundle([]byte("invalid data"), tagRootPublic) + assert.Error(t, err) +} + +// Test PrivateKey + +func TestPrivateKey_JSONMarshaling(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + privKey := PrivateKey{ + Key: priv, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(privKey) + require.NoError(t, err) + + var decoded PrivateKey + err = json.Unmarshal(jsonData, &decoded) + require.NoError(t, err) + + assert.Equal(t, privKey.Key, decoded.Key) + assert.Equal(t, privKey.Metadata.ID, decoded.Metadata.ID) +} + +// Test parsePrivateKey + +func TestParsePrivateKey_Valid(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + privKey := PrivateKey{ + Key: priv, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(privKey) + require.NoError(t, err) + + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPrivate, + Bytes: jsonData, + }) + + parsed, err := parsePrivateKey(pemData, tagRootPrivate) + require.NoError(t, err) + assert.Equal(t, priv, parsed.Key) +} + +func TestParsePrivateKey_InvalidPEM(t *testing.T) { + _, err := parsePrivateKey([]byte("not a PEM"), tagRootPrivate) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to decode PEM") +} + +func TestParsePrivateKey_TrailingData(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + privKey := PrivateKey{ + Key: priv, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(privKey) + require.NoError(t, err) + + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPrivate, + Bytes: jsonData, + }) + + // Add trailing data + pemData = append(pemData, []byte("extra data")...) + + _, err = parsePrivateKey(pemData, tagRootPrivate) + assert.Error(t, err) + assert.Contains(t, err.Error(), "trailing PEM data") +} + +func TestParsePrivateKey_WrongType(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + privKey := PrivateKey{ + Key: priv, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(privKey) + require.NoError(t, err) + + pemData := pem.EncodeToMemory(&pem.Block{ + Type: "WRONG TYPE", + Bytes: jsonData, + }) + + _, err = parsePrivateKey(pemData, tagRootPrivate) + assert.Error(t, err) + assert.Contains(t, err.Error(), "PEM type") +} + +func TestParsePrivateKey_InvalidKeySize(t *testing.T) { + privKey := PrivateKey{ + Key: []byte{0x01, 0x02, 0x03}, // Too short + Metadata: KeyMetadata{ + ID: KeyID{}, + CreatedAt: time.Now().UTC(), + }, + } + + jsonData, err := json.Marshal(privKey) + require.NoError(t, err) + + pemData := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPrivate, + Bytes: jsonData, + }) + + _, err = parsePrivateKey(pemData, tagRootPrivate) + assert.Error(t, err) + assert.Contains(t, err.Error(), "incorrect Ed25519 private key size") +} + +// Test verifyAny + +func TestVerifyAny_ValidSignature(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + message := []byte("test message") + signature := ed25519.Sign(priv, message) + + rootKeys := []PublicKey{ + { + Key: pub, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + result := verifyAny(rootKeys, message, signature) + assert.True(t, result) +} + +func TestVerifyAny_InvalidSignature(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + message := []byte("test message") + invalidSignature := make([]byte, ed25519.SignatureSize) + + rootKeys := []PublicKey{ + { + Key: pub, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + result := verifyAny(rootKeys, message, invalidSignature) + assert.False(t, result) +} + +func TestVerifyAny_MultipleKeys(t *testing.T) { + // Create 3 key pairs + pub1, priv1, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pub2, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pub3, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + message := []byte("test message") + signature := ed25519.Sign(priv1, message) + + rootKeys := []PublicKey{ + {Key: pub2, Metadata: KeyMetadata{ID: computeKeyID(pub2)}}, + {Key: pub1, Metadata: KeyMetadata{ID: computeKeyID(pub1)}}, // Correct key in middle + {Key: pub3, Metadata: KeyMetadata{ID: computeKeyID(pub3)}}, + } + + result := verifyAny(rootKeys, message, signature) + assert.True(t, result) +} + +func TestVerifyAny_NoMatchingKey(t *testing.T) { + _, priv1, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pub2, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + message := []byte("test message") + signature := ed25519.Sign(priv1, message) + + // Only include pub2, not pub1 + rootKeys := []PublicKey{ + {Key: pub2, Metadata: KeyMetadata{ID: computeKeyID(pub2)}}, + } + + result := verifyAny(rootKeys, message, signature) + assert.False(t, result) +} + +func TestVerifyAny_EmptyKeys(t *testing.T) { + message := []byte("test message") + signature := make([]byte, ed25519.SignatureSize) + + result := verifyAny([]PublicKey{}, message, signature) + assert.False(t, result) +} + +func TestVerifyAny_TamperedMessage(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + message := []byte("test message") + signature := ed25519.Sign(priv, message) + + rootKeys := []PublicKey{ + {Key: pub, Metadata: KeyMetadata{ID: computeKeyID(pub)}}, + } + + // Verify with different message + tamperedMessage := []byte("different message") + result := verifyAny(rootKeys, tamperedMessage, signature) + assert.False(t, result) +} diff --git a/client/internal/updatemanager/reposign/revocation.go b/client/internal/updatemanager/reposign/revocation.go new file mode 100644 index 000000000..e679e212f --- /dev/null +++ b/client/internal/updatemanager/reposign/revocation.go @@ -0,0 +1,229 @@ +package reposign + +import ( + "crypto/ed25519" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + maxRevocationSignatureAge = 10 * 365 * 24 * time.Hour + defaultRevocationListExpiration = 365 * 24 * time.Hour +) + +type RevocationList struct { + Revoked map[KeyID]time.Time `json:"revoked"` // KeyID -> revocation time + LastUpdated time.Time `json:"last_updated"` // When the list was last modified + ExpiresAt time.Time `json:"expires_at"` // When the list expires +} + +func (rl RevocationList) MarshalJSON() ([]byte, error) { + // Convert map[KeyID]time.Time to map[string]time.Time + strMap := make(map[string]time.Time, len(rl.Revoked)) + for k, v := range rl.Revoked { + strMap[k.String()] = v + } + + return json.Marshal(map[string]interface{}{ + "revoked": strMap, + "last_updated": rl.LastUpdated, + "expires_at": rl.ExpiresAt, + }) +} + +func (rl *RevocationList) UnmarshalJSON(data []byte) error { + var temp struct { + Revoked map[string]time.Time `json:"revoked"` + LastUpdated time.Time `json:"last_updated"` + ExpiresAt time.Time `json:"expires_at"` + Version int `json:"version"` + } + + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + // Convert map[string]time.Time back to map[KeyID]time.Time + rl.Revoked = make(map[KeyID]time.Time, len(temp.Revoked)) + for k, v := range temp.Revoked { + kid, err := ParseKeyID(k) + if err != nil { + return fmt.Errorf("failed to parse KeyID %q: %w", k, err) + } + rl.Revoked[kid] = v + } + + rl.LastUpdated = temp.LastUpdated + rl.ExpiresAt = temp.ExpiresAt + + return nil +} + +func ParseRevocationList(data []byte) (*RevocationList, error) { + var rl RevocationList + if err := json.Unmarshal(data, &rl); err != nil { + return nil, fmt.Errorf("failed to unmarshal revocation list: %w", err) + } + + // Initialize the map if it's nil (in case of empty JSON object) + if rl.Revoked == nil { + rl.Revoked = make(map[KeyID]time.Time) + } + + if rl.LastUpdated.IsZero() { + return nil, fmt.Errorf("revocation list missing last_updated timestamp") + } + + if rl.ExpiresAt.IsZero() { + return nil, fmt.Errorf("revocation list missing expires_at timestamp") + } + + return &rl, nil +} + +func ValidateRevocationList(publicRootKeys []PublicKey, data []byte, signature Signature) (*RevocationList, error) { + revoList, err := ParseRevocationList(data) + if err != nil { + log.Debugf("failed to parse revocation list: %s", err) + return nil, err + } + + now := time.Now().UTC() + + // Validate signature timestamp + if signature.Timestamp.After(now.Add(maxClockSkew)) { + err := fmt.Errorf("revocation signature timestamp is in the future: %v", signature.Timestamp) + log.Debugf("revocation list signature error: %v", err) + return nil, err + } + + if now.Sub(signature.Timestamp) > maxRevocationSignatureAge { + err := fmt.Errorf("revocation list signature is too old: %v (created %v)", + now.Sub(signature.Timestamp), signature.Timestamp) + log.Debugf("revocation list signature error: %v", err) + return nil, err + } + + // Ensure LastUpdated is not in the future (with clock skew tolerance) + if revoList.LastUpdated.After(now.Add(maxClockSkew)) { + err := fmt.Errorf("revocation list LastUpdated is in the future: %v", revoList.LastUpdated) + log.Errorf("rejecting future-dated revocation list: %v", err) + return nil, err + } + + // Check if the revocation list has expired + if now.After(revoList.ExpiresAt) { + err := fmt.Errorf("revocation list expired at %v (current time: %v)", revoList.ExpiresAt, now) + log.Errorf("rejecting expired revocation list: %v", err) + return nil, err + } + + // Ensure ExpiresAt is not in the future by more than the expected expiration window + // (allows some clock skew but prevents maliciously long expiration times) + if revoList.ExpiresAt.After(now.Add(maxRevocationSignatureAge)) { + err := fmt.Errorf("revocation list ExpiresAt is too far in the future: %v", revoList.ExpiresAt) + log.Errorf("rejecting revocation list with invalid expiration: %v", err) + return nil, err + } + + // Validate signature timestamp is close to LastUpdated + // (prevents signing old lists with new timestamps) + timeDiff := signature.Timestamp.Sub(revoList.LastUpdated).Abs() + if timeDiff > maxClockSkew { + err := fmt.Errorf("signature timestamp %v differs too much from list LastUpdated %v (diff: %v)", + signature.Timestamp, revoList.LastUpdated, timeDiff) + log.Errorf("timestamp mismatch in revocation list: %v", err) + return nil, err + } + + // Reconstruct the signed message: revocation_list_data || timestamp || version + msg := make([]byte, 0, len(data)+8) + msg = append(msg, data...) + msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix())) + + if !verifyAny(publicRootKeys, msg, signature.Signature) { + return nil, errors.New("revocation list verification failed") + } + return revoList, nil +} + +func CreateRevocationList(privateRootKey RootKey, expiration time.Duration) ([]byte, []byte, error) { + now := time.Now() + rl := RevocationList{ + Revoked: make(map[KeyID]time.Time), + LastUpdated: now.UTC(), + ExpiresAt: now.Add(expiration).UTC(), + } + + signature, err := signRevocationList(privateRootKey, rl) + if err != nil { + return nil, nil, fmt.Errorf("failed to sign revocation list: %w", err) + } + + rlData, err := json.Marshal(&rl) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal revocation list: %w", err) + } + + signData, err := json.Marshal(signature) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal signature: %w", err) + } + + return rlData, signData, nil +} + +func ExtendRevocationList(privateRootKey RootKey, rl RevocationList, kid KeyID, expiration time.Duration) ([]byte, []byte, error) { + now := time.Now().UTC() + + rl.Revoked[kid] = now + rl.LastUpdated = now + rl.ExpiresAt = now.Add(expiration) + + signature, err := signRevocationList(privateRootKey, rl) + if err != nil { + return nil, nil, fmt.Errorf("failed to sign revocation list: %w", err) + } + + rlData, err := json.Marshal(&rl) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal revocation list: %w", err) + } + + signData, err := json.Marshal(signature) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal signature: %w", err) + } + + return rlData, signData, nil +} + +func signRevocationList(privateRootKey RootKey, rl RevocationList) (*Signature, error) { + data, err := json.Marshal(rl) + if err != nil { + return nil, fmt.Errorf("failed to marshal revocation list for signing: %w", err) + } + + timestamp := time.Now().UTC() + + msg := make([]byte, 0, len(data)+8) + msg = append(msg, data...) + msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix())) + + sig := ed25519.Sign(privateRootKey.Key, msg) + + signature := &Signature{ + Signature: sig, + Timestamp: timestamp, + KeyID: privateRootKey.Metadata.ID, + Algorithm: "ed25519", + HashAlgo: "sha512", + } + + return signature, nil +} diff --git a/client/internal/updatemanager/reposign/revocation_test.go b/client/internal/updatemanager/reposign/revocation_test.go new file mode 100644 index 000000000..d6d748f3d --- /dev/null +++ b/client/internal/updatemanager/reposign/revocation_test.go @@ -0,0 +1,860 @@ +package reposign + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test RevocationList marshaling/unmarshaling + +func TestRevocationList_MarshalJSON(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + keyID := computeKeyID(pub) + revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC) + expiresAt := time.Date(2024, 4, 15, 11, 0, 0, 0, time.UTC) + + rl := &RevocationList{ + Revoked: map[KeyID]time.Time{ + keyID: revokedTime, + }, + LastUpdated: lastUpdated, + ExpiresAt: expiresAt, + } + + jsonData, err := json.Marshal(rl) + require.NoError(t, err) + + // Verify it can be unmarshaled back + var decoded map[string]interface{} + err = json.Unmarshal(jsonData, &decoded) + require.NoError(t, err) + + assert.Contains(t, decoded, "revoked") + assert.Contains(t, decoded, "last_updated") + assert.Contains(t, decoded, "expires_at") +} + +func TestRevocationList_UnmarshalJSON(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + keyID := computeKeyID(pub) + revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC) + + jsonData := map[string]interface{}{ + "revoked": map[string]string{ + keyID.String(): revokedTime.Format(time.RFC3339), + }, + "last_updated": lastUpdated.Format(time.RFC3339), + } + + jsonBytes, err := json.Marshal(jsonData) + require.NoError(t, err) + + var rl RevocationList + err = json.Unmarshal(jsonBytes, &rl) + require.NoError(t, err) + + assert.Len(t, rl.Revoked, 1) + assert.Contains(t, rl.Revoked, keyID) + assert.Equal(t, lastUpdated.Unix(), rl.LastUpdated.Unix()) +} + +func TestRevocationList_MarshalUnmarshal_Roundtrip(t *testing.T) { + pub1, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + pub2, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + keyID1 := computeKeyID(pub1) + keyID2 := computeKeyID(pub2) + + original := &RevocationList{ + Revoked: map[KeyID]time.Time{ + keyID1: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC), + keyID2: time.Date(2024, 2, 20, 14, 45, 0, 0, time.UTC), + }, + LastUpdated: time.Date(2024, 2, 20, 15, 0, 0, 0, time.UTC), + } + + // Marshal + jsonData, err := original.MarshalJSON() + require.NoError(t, err) + + // Unmarshal + var decoded RevocationList + err = decoded.UnmarshalJSON(jsonData) + require.NoError(t, err) + + // Verify + assert.Len(t, decoded.Revoked, 2) + assert.Equal(t, original.Revoked[keyID1].Unix(), decoded.Revoked[keyID1].Unix()) + assert.Equal(t, original.Revoked[keyID2].Unix(), decoded.Revoked[keyID2].Unix()) + assert.Equal(t, original.LastUpdated.Unix(), decoded.LastUpdated.Unix()) +} + +func TestRevocationList_UnmarshalJSON_InvalidKeyID(t *testing.T) { + jsonData := []byte(`{ + "revoked": { + "invalid_key_id": "2024-01-15T10:30:00Z" + }, + "last_updated": "2024-01-15T11:00:00Z" + }`) + + var rl RevocationList + err := json.Unmarshal(jsonData, &rl) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse KeyID") +} + +func TestRevocationList_EmptyRevoked(t *testing.T) { + rl := &RevocationList{ + Revoked: make(map[KeyID]time.Time), + LastUpdated: time.Now().UTC(), + } + + jsonData, err := rl.MarshalJSON() + require.NoError(t, err) + + var decoded RevocationList + err = decoded.UnmarshalJSON(jsonData) + require.NoError(t, err) + + assert.Empty(t, decoded.Revoked) + assert.NotNil(t, decoded.Revoked) +} + +// Test ParseRevocationList + +func TestParseRevocationList_Valid(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + keyID := computeKeyID(pub) + revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC) + + rl := RevocationList{ + Revoked: map[KeyID]time.Time{ + keyID: revokedTime, + }, + LastUpdated: lastUpdated, + ExpiresAt: time.Date(2025, 2, 20, 14, 45, 0, 0, time.UTC), + } + + jsonData, err := rl.MarshalJSON() + require.NoError(t, err) + + parsed, err := ParseRevocationList(jsonData) + require.NoError(t, err) + assert.NotNil(t, parsed) + assert.Len(t, parsed.Revoked, 1) + assert.Equal(t, lastUpdated.Unix(), parsed.LastUpdated.Unix()) +} + +func TestParseRevocationList_InvalidJSON(t *testing.T) { + invalidJSON := []byte("not valid json") + + _, err := ParseRevocationList(invalidJSON) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to unmarshal") +} + +func TestParseRevocationList_MissingLastUpdated(t *testing.T) { + jsonData := []byte(`{ + "revoked": {} + }`) + + _, err := ParseRevocationList(jsonData) + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing last_updated") +} + +func TestParseRevocationList_EmptyObject(t *testing.T) { + jsonData := []byte(`{}`) + + _, err := ParseRevocationList(jsonData) + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing last_updated") +} + +func TestParseRevocationList_NilRevoked(t *testing.T) { + lastUpdated := time.Now().UTC() + expiresAt := lastUpdated.Add(90 * 24 * time.Hour) + jsonData := []byte(`{ + "last_updated": "` + lastUpdated.Format(time.RFC3339) + `", + "expires_at": "` + expiresAt.Format(time.RFC3339) + `" + }`) + + parsed, err := ParseRevocationList(jsonData) + require.NoError(t, err) + assert.NotNil(t, parsed.Revoked) + assert.Empty(t, parsed.Revoked) +} + +func TestParseRevocationList_MissingExpiresAt(t *testing.T) { + lastUpdated := time.Now().UTC() + jsonData := []byte(`{ + "revoked": {}, + "last_updated": "` + lastUpdated.Format(time.RFC3339) + `" + }`) + + _, err := ParseRevocationList(jsonData) + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing expires_at") +} + +// Test ValidateRevocationList + +func TestValidateRevocationList_Valid(t *testing.T) { + // Generate root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create revocation list + rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + signature, err := ParseSignature(sigData) + require.NoError(t, err) + + // Validate + rl, err := ValidateRevocationList(rootKeys, rlData, *signature) + require.NoError(t, err) + assert.NotNil(t, rl) + assert.Empty(t, rl.Revoked) +} + +func TestValidateRevocationList_InvalidSignature(t *testing.T) { + // Generate root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create revocation list + rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + // Create invalid signature + invalidSig := Signature{ + Signature: make([]byte, 64), + Timestamp: time.Now().UTC(), + KeyID: computeKeyID(rootPub), + Algorithm: "ed25519", + HashAlgo: "sha512", + } + + // Validate should fail + _, err = ValidateRevocationList(rootKeys, rlData, invalidSig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "verification failed") +} + +func TestValidateRevocationList_FutureTimestamp(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + signature, err := ParseSignature(sigData) + require.NoError(t, err) + + // Modify timestamp to be in the future + signature.Timestamp = time.Now().UTC().Add(10 * time.Minute) + + _, err = ValidateRevocationList(rootKeys, rlData, *signature) + assert.Error(t, err) + assert.Contains(t, err.Error(), "in the future") +} + +func TestValidateRevocationList_TooOld(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + signature, err := ParseSignature(sigData) + require.NoError(t, err) + + // Modify timestamp to be too old + signature.Timestamp = time.Now().UTC().Add(-20 * 365 * 24 * time.Hour) + + _, err = ValidateRevocationList(rootKeys, rlData, *signature) + assert.Error(t, err) + assert.Contains(t, err.Error(), "too old") +} + +func TestValidateRevocationList_InvalidJSON(t *testing.T) { + rootPub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + signature := Signature{ + Signature: make([]byte, 64), + Timestamp: time.Now().UTC(), + KeyID: computeKeyID(rootPub), + Algorithm: "ed25519", + HashAlgo: "sha512", + } + + _, err = ValidateRevocationList(rootKeys, []byte("invalid json"), signature) + assert.Error(t, err) +} + +func TestValidateRevocationList_FutureLastUpdated(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create revocation list with future LastUpdated + rl := RevocationList{ + Revoked: make(map[KeyID]time.Time), + LastUpdated: time.Now().UTC().Add(10 * time.Minute), + ExpiresAt: time.Now().UTC().Add(365 * 24 * time.Hour), + } + + rlData, err := json.Marshal(rl) + require.NoError(t, err) + + // Sign it + sig, err := signRevocationList(rootKey, rl) + require.NoError(t, err) + + _, err = ValidateRevocationList(rootKeys, rlData, *sig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "LastUpdated is in the future") +} + +func TestValidateRevocationList_TimestampMismatch(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create revocation list with LastUpdated far in the past + rl := RevocationList{ + Revoked: make(map[KeyID]time.Time), + LastUpdated: time.Now().UTC().Add(-1 * time.Hour), + ExpiresAt: time.Now().UTC().Add(365 * 24 * time.Hour), + } + + rlData, err := json.Marshal(rl) + require.NoError(t, err) + + // Sign it with current timestamp + sig, err := signRevocationList(rootKey, rl) + require.NoError(t, err) + + // Modify signature timestamp to differ too much from LastUpdated + sig.Timestamp = time.Now().UTC() + + _, err = ValidateRevocationList(rootKeys, rlData, *sig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "differs too much") +} + +func TestValidateRevocationList_Expired(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create revocation list that expired in the past + now := time.Now().UTC() + rl := RevocationList{ + Revoked: make(map[KeyID]time.Time), + LastUpdated: now.Add(-100 * 24 * time.Hour), + ExpiresAt: now.Add(-10 * 24 * time.Hour), // Expired 10 days ago + } + + rlData, err := json.Marshal(rl) + require.NoError(t, err) + + // Sign it + sig, err := signRevocationList(rootKey, rl) + require.NoError(t, err) + // Adjust signature timestamp to match LastUpdated + sig.Timestamp = rl.LastUpdated + + _, err = ValidateRevocationList(rootKeys, rlData, *sig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "expired") +} + +func TestValidateRevocationList_ExpiresAtTooFarInFuture(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create revocation list with ExpiresAt too far in the future (beyond maxRevocationSignatureAge) + now := time.Now().UTC() + rl := RevocationList{ + Revoked: make(map[KeyID]time.Time), + LastUpdated: now, + ExpiresAt: now.Add(15 * 365 * 24 * time.Hour), // 15 years in the future + } + + rlData, err := json.Marshal(rl) + require.NoError(t, err) + + // Sign it + sig, err := signRevocationList(rootKey, rl) + require.NoError(t, err) + + _, err = ValidateRevocationList(rootKeys, rlData, *sig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "too far in the future") +} + +// Test CreateRevocationList + +func TestCreateRevocationList_Valid(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + assert.NotEmpty(t, rlData) + assert.NotEmpty(t, sigData) + + // Verify it can be parsed + rl, err := ParseRevocationList(rlData) + require.NoError(t, err) + assert.Empty(t, rl.Revoked) + assert.False(t, rl.LastUpdated.IsZero()) + + // Verify signature can be parsed + sig, err := ParseSignature(sigData) + require.NoError(t, err) + assert.NotEmpty(t, sig.Signature) +} + +// Test ExtendRevocationList + +func TestExtendRevocationList_AddKey(t *testing.T) { + // Generate root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create empty revocation list + rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err := ParseRevocationList(rlData) + require.NoError(t, err) + assert.Empty(t, rl.Revoked) + + // Generate a key to revoke + revokedPub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + revokedKeyID := computeKeyID(revokedPub) + + // Extend the revocation list + newRLData, newSigData, err := ExtendRevocationList(rootKey, *rl, revokedKeyID, defaultRevocationListExpiration) + require.NoError(t, err) + + // Verify the new list + newRL, err := ParseRevocationList(newRLData) + require.NoError(t, err) + assert.Len(t, newRL.Revoked, 1) + assert.Contains(t, newRL.Revoked, revokedKeyID) + + // Verify signature + sig, err := ParseSignature(newSigData) + require.NoError(t, err) + assert.NotEmpty(t, sig.Signature) +} + +func TestExtendRevocationList_MultipleKeys(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create empty revocation list + rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err := ParseRevocationList(rlData) + require.NoError(t, err) + + // Add first key + key1Pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + key1ID := computeKeyID(key1Pub) + + rlData, _, err = ExtendRevocationList(rootKey, *rl, key1ID, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err = ParseRevocationList(rlData) + require.NoError(t, err) + assert.Len(t, rl.Revoked, 1) + + // Add second key + key2Pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + key2ID := computeKeyID(key2Pub) + + rlData, _, err = ExtendRevocationList(rootKey, *rl, key2ID, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err = ParseRevocationList(rlData) + require.NoError(t, err) + assert.Len(t, rl.Revoked, 2) + assert.Contains(t, rl.Revoked, key1ID) + assert.Contains(t, rl.Revoked, key2ID) +} + +func TestExtendRevocationList_DuplicateKey(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create empty revocation list + rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err := ParseRevocationList(rlData) + require.NoError(t, err) + + // Add a key + keyPub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + keyID := computeKeyID(keyPub) + + rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err = ParseRevocationList(rlData) + require.NoError(t, err) + firstRevocationTime := rl.Revoked[keyID] + + // Wait a bit + time.Sleep(10 * time.Millisecond) + + // Add the same key again + rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err = ParseRevocationList(rlData) + require.NoError(t, err) + assert.Len(t, rl.Revoked, 1) + + // The revocation time should be updated + secondRevocationTime := rl.Revoked[keyID] + assert.True(t, secondRevocationTime.After(firstRevocationTime) || secondRevocationTime.Equal(firstRevocationTime)) +} + +func TestExtendRevocationList_UpdatesLastUpdated(t *testing.T) { + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Create revocation list + rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err := ParseRevocationList(rlData) + require.NoError(t, err) + firstLastUpdated := rl.LastUpdated + + // Wait a bit + time.Sleep(10 * time.Millisecond) + + // Extend list + keyPub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + keyID := computeKeyID(keyPub) + + rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration) + require.NoError(t, err) + + rl, err = ParseRevocationList(rlData) + require.NoError(t, err) + + // LastUpdated should be updated + assert.True(t, rl.LastUpdated.After(firstLastUpdated)) +} + +// Integration test + +func TestRevocationList_FullWorkflow(t *testing.T) { + // Create root key + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + rootKey := RootKey{ + PrivateKey{ + Key: rootPriv, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + rootKeys := []PublicKey{ + { + Key: rootPub, + Metadata: KeyMetadata{ + ID: computeKeyID(rootPub), + CreatedAt: time.Now().UTC(), + }, + }, + } + + // Step 1: Create empty revocation list + rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + // Step 2: Validate it + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + rl, err := ValidateRevocationList(rootKeys, rlData, *sig) + require.NoError(t, err) + assert.Empty(t, rl.Revoked) + + // Step 3: Revoke a key + revokedPub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + revokedKeyID := computeKeyID(revokedPub) + + rlData, sigData, err = ExtendRevocationList(rootKey, *rl, revokedKeyID, defaultRevocationListExpiration) + require.NoError(t, err) + + // Step 4: Validate the extended list + sig, err = ParseSignature(sigData) + require.NoError(t, err) + + rl, err = ValidateRevocationList(rootKeys, rlData, *sig) + require.NoError(t, err) + assert.Len(t, rl.Revoked, 1) + assert.Contains(t, rl.Revoked, revokedKeyID) + + // Step 5: Verify the revocation time is reasonable + revTime := rl.Revoked[revokedKeyID] + now := time.Now().UTC() + assert.True(t, revTime.Before(now) || revTime.Equal(now)) + assert.True(t, now.Sub(revTime) < time.Minute) +} diff --git a/client/internal/updatemanager/reposign/root.go b/client/internal/updatemanager/reposign/root.go new file mode 100644 index 000000000..2c3ca54a0 --- /dev/null +++ b/client/internal/updatemanager/reposign/root.go @@ -0,0 +1,120 @@ +package reposign + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/binary" + "encoding/json" + "encoding/pem" + "fmt" + "time" +) + +const ( + tagRootPrivate = "ROOT PRIVATE KEY" + tagRootPublic = "ROOT PUBLIC KEY" +) + +// RootKey is a root Key used to sign signing keys +type RootKey struct { + PrivateKey +} + +func (k RootKey) String() string { + return fmt.Sprintf( + "RootKey[ID=%s, CreatedAt=%s, ExpiresAt=%s]", + k.Metadata.ID, + k.Metadata.CreatedAt.Format(time.RFC3339), + k.Metadata.ExpiresAt.Format(time.RFC3339), + ) +} + +func ParseRootKey(privKeyPEM []byte) (*RootKey, error) { + pk, err := parsePrivateKey(privKeyPEM, tagRootPrivate) + if err != nil { + return nil, fmt.Errorf("failed to parse root Key: %w", err) + } + return &RootKey{pk}, nil +} + +// ParseRootPublicKey parses a root public key from PEM format +func ParseRootPublicKey(pubKeyPEM []byte) (PublicKey, error) { + pk, _, err := parsePublicKey(pubKeyPEM, tagRootPublic) + if err != nil { + return PublicKey{}, fmt.Errorf("failed to parse root public key: %w", err) + } + return pk, nil +} + +// GenerateRootKey generates a new root Key pair with Metadata +func GenerateRootKey(expiration time.Duration) (*RootKey, []byte, []byte, error) { + now := time.Now() + expirationTime := now.Add(expiration) + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, nil, err + } + + metadata := KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: now.UTC(), + ExpiresAt: expirationTime.UTC(), + } + + rk := &RootKey{ + PrivateKey{ + Key: priv, + Metadata: metadata, + }, + } + + // Marshal PrivateKey struct to JSON + privJSON, err := json.Marshal(rk.PrivateKey) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to marshal private key: %w", err) + } + + // Marshal PublicKey struct to JSON + pubKey := PublicKey{ + Key: pub, + Metadata: metadata, + } + pubJSON, err := json.Marshal(pubKey) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to marshal public key: %w", err) + } + + // Encode to PEM with metadata embedded in bytes + privPEM := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPrivate, + Bytes: privJSON, + }) + + pubPEM := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPublic, + Bytes: pubJSON, + }) + + return rk, privPEM, pubPEM, nil +} + +func SignArtifactKey(rootKey RootKey, data []byte) ([]byte, error) { + timestamp := time.Now().UTC() + + // This ensures the timestamp is cryptographically bound to the signature + msg := make([]byte, 0, len(data)+8) + msg = append(msg, data...) + msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix())) + + sig := ed25519.Sign(rootKey.Key, msg) + // Create signature bundle with timestamp and Metadata + bundle := Signature{ + Signature: sig, + Timestamp: timestamp, + KeyID: rootKey.Metadata.ID, + Algorithm: "ed25519", + HashAlgo: "sha512", + } + + return json.Marshal(bundle) +} diff --git a/client/internal/updatemanager/reposign/root_test.go b/client/internal/updatemanager/reposign/root_test.go new file mode 100644 index 000000000..e75e29729 --- /dev/null +++ b/client/internal/updatemanager/reposign/root_test.go @@ -0,0 +1,476 @@ +package reposign + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/binary" + "encoding/json" + "encoding/pem" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test RootKey.String() + +func TestRootKey_String(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + createdAt := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + expiresAt := time.Date(2034, 1, 15, 10, 30, 0, 0, time.UTC) + + rk := RootKey{ + PrivateKey{ + Key: priv, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: createdAt, + ExpiresAt: expiresAt, + }, + }, + } + + str := rk.String() + assert.Contains(t, str, "RootKey") + assert.Contains(t, str, computeKeyID(pub).String()) + assert.Contains(t, str, "2024-01-15") + assert.Contains(t, str, "2034-01-15") +} + +func TestRootKey_String_NoExpiration(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + createdAt := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + + rk := RootKey{ + PrivateKey{ + Key: priv, + Metadata: KeyMetadata{ + ID: computeKeyID(pub), + CreatedAt: createdAt, + ExpiresAt: time.Time{}, // No expiration + }, + }, + } + + str := rk.String() + assert.Contains(t, str, "RootKey") + assert.Contains(t, str, "0001-01-01") // Zero time format +} + +// Test GenerateRootKey + +func TestGenerateRootKey_Valid(t *testing.T) { + expiration := 10 * 365 * 24 * time.Hour // 10 years + + rk, privPEM, pubPEM, err := GenerateRootKey(expiration) + require.NoError(t, err) + assert.NotNil(t, rk) + assert.NotEmpty(t, privPEM) + assert.NotEmpty(t, pubPEM) + + // Verify the key has correct metadata + assert.False(t, rk.Metadata.CreatedAt.IsZero()) + assert.False(t, rk.Metadata.ExpiresAt.IsZero()) + assert.True(t, rk.Metadata.ExpiresAt.After(rk.Metadata.CreatedAt)) + + // Verify expiration is approximately correct + expectedExpiration := time.Now().Add(expiration) + timeDiff := rk.Metadata.ExpiresAt.Sub(expectedExpiration) + assert.True(t, timeDiff < time.Minute && timeDiff > -time.Minute) +} + +func TestGenerateRootKey_ShortExpiration(t *testing.T) { + expiration := 24 * time.Hour // 1 day + + rk, _, _, err := GenerateRootKey(expiration) + require.NoError(t, err) + assert.NotNil(t, rk) + + // Verify expiration + expectedExpiration := time.Now().Add(expiration) + timeDiff := rk.Metadata.ExpiresAt.Sub(expectedExpiration) + assert.True(t, timeDiff < time.Minute && timeDiff > -time.Minute) +} + +func TestGenerateRootKey_ZeroExpiration(t *testing.T) { + rk, _, _, err := GenerateRootKey(0) + require.NoError(t, err) + assert.NotNil(t, rk) + + // With zero expiration, ExpiresAt should be equal to CreatedAt + assert.Equal(t, rk.Metadata.CreatedAt, rk.Metadata.ExpiresAt) +} + +func TestGenerateRootKey_PEMFormat(t *testing.T) { + rk, privPEM, pubPEM, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + // Verify private key PEM + privBlock, _ := pem.Decode(privPEM) + require.NotNil(t, privBlock) + assert.Equal(t, tagRootPrivate, privBlock.Type) + + var privKey PrivateKey + err = json.Unmarshal(privBlock.Bytes, &privKey) + require.NoError(t, err) + assert.Equal(t, rk.Key, privKey.Key) + + // Verify public key PEM + pubBlock, _ := pem.Decode(pubPEM) + require.NotNil(t, pubBlock) + assert.Equal(t, tagRootPublic, pubBlock.Type) + + var pubKey PublicKey + err = json.Unmarshal(pubBlock.Bytes, &pubKey) + require.NoError(t, err) + assert.Equal(t, rk.Metadata.ID, pubKey.Metadata.ID) +} + +func TestGenerateRootKey_KeySize(t *testing.T) { + rk, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + // Ed25519 private key should be 64 bytes + assert.Equal(t, ed25519.PrivateKeySize, len(rk.Key)) + + // Ed25519 public key should be 32 bytes + pubKey := rk.Key.Public().(ed25519.PublicKey) + assert.Equal(t, ed25519.PublicKeySize, len(pubKey)) +} + +func TestGenerateRootKey_UniqueKeys(t *testing.T) { + rk1, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + rk2, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + // Different keys should have different IDs + assert.NotEqual(t, rk1.Metadata.ID, rk2.Metadata.ID) + assert.NotEqual(t, rk1.Key, rk2.Key) +} + +// Test ParseRootKey + +func TestParseRootKey_Valid(t *testing.T) { + original, privPEM, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + parsed, err := ParseRootKey(privPEM) + require.NoError(t, err) + assert.NotNil(t, parsed) + + // Verify the parsed key matches the original + assert.Equal(t, original.Key, parsed.Key) + assert.Equal(t, original.Metadata.ID, parsed.Metadata.ID) + assert.Equal(t, original.Metadata.CreatedAt.Unix(), parsed.Metadata.CreatedAt.Unix()) + assert.Equal(t, original.Metadata.ExpiresAt.Unix(), parsed.Metadata.ExpiresAt.Unix()) +} + +func TestParseRootKey_InvalidPEM(t *testing.T) { + _, err := ParseRootKey([]byte("not a valid PEM")) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse") +} + +func TestParseRootKey_EmptyData(t *testing.T) { + _, err := ParseRootKey([]byte{}) + assert.Error(t, err) +} + +func TestParseRootKey_WrongType(t *testing.T) { + // Generate an artifact key instead of root key + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + artifactKey, privPEM, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + // Try to parse artifact key as root key + _, err = ParseRootKey(privPEM) + assert.Error(t, err) + assert.Contains(t, err.Error(), "PEM type") + + // Just to use artifactKey to avoid unused variable warning + _ = artifactKey +} + +func TestParseRootKey_CorruptedJSON(t *testing.T) { + // Create PEM with corrupted JSON + corruptedPEM := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPrivate, + Bytes: []byte("corrupted json data"), + }) + + _, err := ParseRootKey(corruptedPEM) + assert.Error(t, err) +} + +func TestParseRootKey_InvalidKeySize(t *testing.T) { + // Create a key with invalid size + invalidKey := PrivateKey{ + Key: []byte{0x01, 0x02, 0x03}, // Too short + Metadata: KeyMetadata{ + ID: KeyID{}, + CreatedAt: time.Now().UTC(), + }, + } + + privJSON, err := json.Marshal(invalidKey) + require.NoError(t, err) + + invalidPEM := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPrivate, + Bytes: privJSON, + }) + + _, err = ParseRootKey(invalidPEM) + assert.Error(t, err) + assert.Contains(t, err.Error(), "incorrect Ed25519 private key size") +} + +func TestParseRootKey_Roundtrip(t *testing.T) { + // Generate a key + original, privPEM, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + // Parse it + parsed, err := ParseRootKey(privPEM) + require.NoError(t, err) + + // Generate PEM again from parsed key + privJSON2, err := json.Marshal(parsed.PrivateKey) + require.NoError(t, err) + + privPEM2 := pem.EncodeToMemory(&pem.Block{ + Type: tagRootPrivate, + Bytes: privJSON2, + }) + + // Parse again + parsed2, err := ParseRootKey(privPEM2) + require.NoError(t, err) + + // Should still match original + assert.Equal(t, original.Key, parsed2.Key) + assert.Equal(t, original.Metadata.ID, parsed2.Metadata.ID) +} + +// Test SignArtifactKey + +func TestSignArtifactKey_Valid(t *testing.T) { + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + data := []byte("test data to sign") + sigData, err := SignArtifactKey(*rootKey, data) + require.NoError(t, err) + assert.NotEmpty(t, sigData) + + // Parse and verify signature + sig, err := ParseSignature(sigData) + require.NoError(t, err) + assert.NotEmpty(t, sig.Signature) + assert.Equal(t, rootKey.Metadata.ID, sig.KeyID) + assert.Equal(t, "ed25519", sig.Algorithm) + assert.Equal(t, "sha512", sig.HashAlgo) + assert.False(t, sig.Timestamp.IsZero()) +} + +func TestSignArtifactKey_EmptyData(t *testing.T) { + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + sigData, err := SignArtifactKey(*rootKey, []byte{}) + require.NoError(t, err) + assert.NotEmpty(t, sigData) + + // Should still be able to parse + sig, err := ParseSignature(sigData) + require.NoError(t, err) + assert.NotEmpty(t, sig.Signature) +} + +func TestSignArtifactKey_Verify(t *testing.T) { + rootKey, _, pubPEM, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + // Parse public key + pubKey, _, err := parsePublicKey(pubPEM, tagRootPublic) + require.NoError(t, err) + + // Sign some data + data := []byte("test data for verification") + sigData, err := SignArtifactKey(*rootKey, data) + require.NoError(t, err) + + // Parse signature + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // Reconstruct message + msg := make([]byte, 0, len(data)+8) + msg = append(msg, data...) + msg = binary.LittleEndian.AppendUint64(msg, uint64(sig.Timestamp.Unix())) + + // Verify signature + valid := ed25519.Verify(pubKey.Key, msg, sig.Signature) + assert.True(t, valid) +} + +func TestSignArtifactKey_DifferentData(t *testing.T) { + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + data1 := []byte("data1") + data2 := []byte("data2") + + sig1, err := SignArtifactKey(*rootKey, data1) + require.NoError(t, err) + + sig2, err := SignArtifactKey(*rootKey, data2) + require.NoError(t, err) + + // Different data should produce different signatures + assert.NotEqual(t, sig1, sig2) +} + +func TestSignArtifactKey_MultipleSignatures(t *testing.T) { + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + data := []byte("test data") + + // Sign twice with a small delay + sig1, err := SignArtifactKey(*rootKey, data) + require.NoError(t, err) + + time.Sleep(10 * time.Millisecond) + + sig2, err := SignArtifactKey(*rootKey, data) + require.NoError(t, err) + + // Signatures should be different due to different timestamps + assert.NotEqual(t, sig1, sig2) + + // Parse both signatures + parsed1, err := ParseSignature(sig1) + require.NoError(t, err) + + parsed2, err := ParseSignature(sig2) + require.NoError(t, err) + + // Timestamps should be different + assert.True(t, parsed2.Timestamp.After(parsed1.Timestamp)) +} + +func TestSignArtifactKey_LargeData(t *testing.T) { + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + // Create 1MB of data + largeData := make([]byte, 1024*1024) + for i := range largeData { + largeData[i] = byte(i % 256) + } + + sigData, err := SignArtifactKey(*rootKey, largeData) + require.NoError(t, err) + assert.NotEmpty(t, sigData) + + // Verify signature can be parsed + sig, err := ParseSignature(sigData) + require.NoError(t, err) + assert.NotEmpty(t, sig.Signature) +} + +func TestSignArtifactKey_TimestampInSignature(t *testing.T) { + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + beforeSign := time.Now().UTC() + data := []byte("test data") + sigData, err := SignArtifactKey(*rootKey, data) + require.NoError(t, err) + afterSign := time.Now().UTC() + + sig, err := ParseSignature(sigData) + require.NoError(t, err) + + // Timestamp should be between before and after + assert.True(t, sig.Timestamp.After(beforeSign.Add(-time.Second))) + assert.True(t, sig.Timestamp.Before(afterSign.Add(time.Second))) +} + +// Integration test + +func TestRootKey_FullWorkflow(t *testing.T) { + // Step 1: Generate root key + rootKey, privPEM, pubPEM, err := GenerateRootKey(10 * 365 * 24 * time.Hour) + require.NoError(t, err) + assert.NotNil(t, rootKey) + assert.NotEmpty(t, privPEM) + assert.NotEmpty(t, pubPEM) + + // Step 2: Parse the private key back + parsedRootKey, err := ParseRootKey(privPEM) + require.NoError(t, err) + assert.Equal(t, rootKey.Key, parsedRootKey.Key) + assert.Equal(t, rootKey.Metadata.ID, parsedRootKey.Metadata.ID) + + // Step 3: Generate an artifact key using root key + artifactKey, _, artifactPubPEM, artifactSig, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + assert.NotNil(t, artifactKey) + + // Step 4: Verify the artifact key signature + pubKey, _, err := parsePublicKey(pubPEM, tagRootPublic) + require.NoError(t, err) + + sig, err := ParseSignature(artifactSig) + require.NoError(t, err) + + artifactPubKey, _, err := parsePublicKey(artifactPubPEM, tagArtifactPublic) + require.NoError(t, err) + + // Reconstruct message - SignArtifactKey signs the PEM, not the JSON + msg := make([]byte, 0, len(artifactPubPEM)+8) + msg = append(msg, artifactPubPEM...) + msg = binary.LittleEndian.AppendUint64(msg, uint64(sig.Timestamp.Unix())) + + // Verify with root public key + valid := ed25519.Verify(pubKey.Key, msg, sig.Signature) + assert.True(t, valid, "Artifact key signature should be valid") + + // Step 5: Use artifact key to sign data + testData := []byte("This is test artifact data") + dataSig, err := SignData(*artifactKey, testData) + require.NoError(t, err) + assert.NotEmpty(t, dataSig) + + // Step 6: Verify the artifact data signature + dataSigParsed, err := ParseSignature(dataSig) + require.NoError(t, err) + + err = ValidateArtifact([]PublicKey{artifactPubKey}, testData, *dataSigParsed) + assert.NoError(t, err, "Artifact data signature should be valid") +} + +func TestRootKey_ExpiredKeyWorkflow(t *testing.T) { + // Generate a root key that expires very soon + rootKey, _, _, err := GenerateRootKey(1 * time.Millisecond) + require.NoError(t, err) + + // Wait for expiration + time.Sleep(10 * time.Millisecond) + + // Try to generate artifact key with expired root key + _, _, _, _, err = GenerateArtifactKey(rootKey, 30*24*time.Hour) + assert.Error(t, err) + assert.Contains(t, err.Error(), "expired") +} diff --git a/client/internal/updatemanager/reposign/signature.go b/client/internal/updatemanager/reposign/signature.go new file mode 100644 index 000000000..c7f06e94e --- /dev/null +++ b/client/internal/updatemanager/reposign/signature.go @@ -0,0 +1,24 @@ +package reposign + +import ( + "encoding/json" + "time" +) + +// Signature contains a signature with associated Metadata +type Signature struct { + Signature []byte `json:"signature"` + Timestamp time.Time `json:"timestamp"` + KeyID KeyID `json:"key_id"` + Algorithm string `json:"algorithm"` // "ed25519" + HashAlgo string `json:"hash_algo"` // "blake2s" or sha512 +} + +func ParseSignature(data []byte) (*Signature, error) { + var signature Signature + if err := json.Unmarshal(data, &signature); err != nil { + return nil, err + } + + return &signature, nil +} diff --git a/client/internal/updatemanager/reposign/signature_test.go b/client/internal/updatemanager/reposign/signature_test.go new file mode 100644 index 000000000..1960c5518 --- /dev/null +++ b/client/internal/updatemanager/reposign/signature_test.go @@ -0,0 +1,277 @@ +package reposign + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseSignature_Valid(t *testing.T) { + timestamp := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + keyID, err := ParseKeyID("0123456789abcdef") + require.NoError(t, err) + + signatureData := []byte{0x01, 0x02, 0x03, 0x04} + + jsonData, err := json.Marshal(Signature{ + Signature: signatureData, + Timestamp: timestamp, + KeyID: keyID, + Algorithm: "ed25519", + HashAlgo: "blake2s", + }) + require.NoError(t, err) + + sig, err := ParseSignature(jsonData) + require.NoError(t, err) + assert.NotNil(t, sig) + assert.Equal(t, signatureData, sig.Signature) + assert.Equal(t, timestamp.Unix(), sig.Timestamp.Unix()) + assert.Equal(t, keyID, sig.KeyID) + assert.Equal(t, "ed25519", sig.Algorithm) + assert.Equal(t, "blake2s", sig.HashAlgo) +} + +func TestParseSignature_InvalidJSON(t *testing.T) { + invalidJSON := []byte(`{invalid json}`) + + sig, err := ParseSignature(invalidJSON) + assert.Error(t, err) + assert.Nil(t, sig) +} + +func TestParseSignature_EmptyData(t *testing.T) { + emptyJSON := []byte(`{}`) + + sig, err := ParseSignature(emptyJSON) + require.NoError(t, err) + assert.NotNil(t, sig) + assert.Empty(t, sig.Signature) + assert.True(t, sig.Timestamp.IsZero()) + assert.Equal(t, KeyID{}, sig.KeyID) + assert.Empty(t, sig.Algorithm) + assert.Empty(t, sig.HashAlgo) +} + +func TestParseSignature_MissingFields(t *testing.T) { + // JSON with only some fields + partialJSON := []byte(`{ + "signature": "AQIDBA==", + "algorithm": "ed25519" + }`) + + sig, err := ParseSignature(partialJSON) + require.NoError(t, err) + assert.NotNil(t, sig) + assert.NotEmpty(t, sig.Signature) + assert.Equal(t, "ed25519", sig.Algorithm) + assert.True(t, sig.Timestamp.IsZero()) +} + +func TestSignature_MarshalUnmarshal_Roundtrip(t *testing.T) { + timestamp := time.Date(2024, 6, 20, 14, 45, 30, 0, time.UTC) + keyID, err := ParseKeyID("fedcba9876543210") + require.NoError(t, err) + + original := Signature{ + Signature: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}, + Timestamp: timestamp, + KeyID: keyID, + Algorithm: "ed25519", + HashAlgo: "sha512", + } + + // Marshal + jsonData, err := json.Marshal(original) + require.NoError(t, err) + + // Unmarshal + parsed, err := ParseSignature(jsonData) + require.NoError(t, err) + + // Verify + assert.Equal(t, original.Signature, parsed.Signature) + assert.Equal(t, original.Timestamp.Unix(), parsed.Timestamp.Unix()) + assert.Equal(t, original.KeyID, parsed.KeyID) + assert.Equal(t, original.Algorithm, parsed.Algorithm) + assert.Equal(t, original.HashAlgo, parsed.HashAlgo) +} + +func TestSignature_NilSignatureBytes(t *testing.T) { + timestamp := time.Now().UTC() + keyID, err := ParseKeyID("0011223344556677") + require.NoError(t, err) + + sig := Signature{ + Signature: nil, + Timestamp: timestamp, + KeyID: keyID, + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + jsonData, err := json.Marshal(sig) + require.NoError(t, err) + + parsed, err := ParseSignature(jsonData) + require.NoError(t, err) + assert.Nil(t, parsed.Signature) +} + +func TestSignature_LargeSignature(t *testing.T) { + timestamp := time.Now().UTC() + keyID, err := ParseKeyID("aabbccddeeff0011") + require.NoError(t, err) + + // Create a large signature (64 bytes for ed25519) + largeSignature := make([]byte, 64) + for i := range largeSignature { + largeSignature[i] = byte(i) + } + + sig := Signature{ + Signature: largeSignature, + Timestamp: timestamp, + KeyID: keyID, + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + jsonData, err := json.Marshal(sig) + require.NoError(t, err) + + parsed, err := ParseSignature(jsonData) + require.NoError(t, err) + assert.Equal(t, largeSignature, parsed.Signature) +} + +func TestSignature_WithDifferentHashAlgorithms(t *testing.T) { + tests := []struct { + name string + hashAlgo string + }{ + {"blake2s", "blake2s"}, + {"sha512", "sha512"}, + {"sha256", "sha256"}, + {"empty", ""}, + } + + keyID, err := ParseKeyID("1122334455667788") + require.NoError(t, err) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sig := Signature{ + Signature: []byte{0x01, 0x02}, + Timestamp: time.Now().UTC(), + KeyID: keyID, + Algorithm: "ed25519", + HashAlgo: tt.hashAlgo, + } + + jsonData, err := json.Marshal(sig) + require.NoError(t, err) + + parsed, err := ParseSignature(jsonData) + require.NoError(t, err) + assert.Equal(t, tt.hashAlgo, parsed.HashAlgo) + }) + } +} + +func TestSignature_TimestampPrecision(t *testing.T) { + // Test that timestamp preserves precision through JSON marshaling + timestamp := time.Date(2024, 3, 15, 10, 30, 45, 123456789, time.UTC) + keyID, err := ParseKeyID("8877665544332211") + require.NoError(t, err) + + sig := Signature{ + Signature: []byte{0xaa, 0xbb}, + Timestamp: timestamp, + KeyID: keyID, + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + jsonData, err := json.Marshal(sig) + require.NoError(t, err) + + parsed, err := ParseSignature(jsonData) + require.NoError(t, err) + + // JSON timestamps typically have second or millisecond precision + // so we check that at least seconds match + assert.Equal(t, timestamp.Unix(), parsed.Timestamp.Unix()) +} + +func TestParseSignature_MalformedKeyID(t *testing.T) { + // Test with a malformed KeyID field + malformedJSON := []byte(`{ + "signature": "AQID", + "timestamp": "2024-01-15T10:30:00Z", + "key_id": "invalid_keyid_format", + "algorithm": "ed25519", + "hash_algo": "blake2s" + }`) + + // This should fail since "invalid_keyid_format" is not a valid KeyID + sig, err := ParseSignature(malformedJSON) + assert.Error(t, err) + assert.Nil(t, sig) +} + +func TestParseSignature_InvalidTimestamp(t *testing.T) { + // Test with an invalid timestamp format + invalidTimestampJSON := []byte(`{ + "signature": "AQID", + "timestamp": "not-a-timestamp", + "key_id": "0123456789abcdef", + "algorithm": "ed25519", + "hash_algo": "blake2s" + }`) + + sig, err := ParseSignature(invalidTimestampJSON) + assert.Error(t, err) + assert.Nil(t, sig) +} + +func TestSignature_ZeroKeyID(t *testing.T) { + // Test with a zero KeyID + sig := Signature{ + Signature: []byte{0x01, 0x02, 0x03}, + Timestamp: time.Now().UTC(), + KeyID: KeyID{}, + Algorithm: "ed25519", + HashAlgo: "blake2s", + } + + jsonData, err := json.Marshal(sig) + require.NoError(t, err) + + parsed, err := ParseSignature(jsonData) + require.NoError(t, err) + assert.Equal(t, KeyID{}, parsed.KeyID) +} + +func TestParseSignature_ExtraFields(t *testing.T) { + // JSON with extra fields that should be ignored + jsonWithExtra := []byte(`{ + "signature": "AQIDBA==", + "timestamp": "2024-01-15T10:30:00Z", + "key_id": "0123456789abcdef", + "algorithm": "ed25519", + "hash_algo": "blake2s", + "extra_field": "should be ignored", + "another_extra": 12345 + }`) + + sig, err := ParseSignature(jsonWithExtra) + require.NoError(t, err) + assert.NotNil(t, sig) + assert.NotEmpty(t, sig.Signature) + assert.Equal(t, "ed25519", sig.Algorithm) + assert.Equal(t, "blake2s", sig.HashAlgo) +} diff --git a/client/internal/updatemanager/reposign/verify.go b/client/internal/updatemanager/reposign/verify.go new file mode 100644 index 000000000..0af2a8c9e --- /dev/null +++ b/client/internal/updatemanager/reposign/verify.go @@ -0,0 +1,187 @@ +package reposign + +import ( + "context" + "fmt" + "net/url" + "os" + "path/filepath" + "strings" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/updatemanager/downloader" +) + +const ( + artifactPubKeysFileName = "artifact-key-pub.pem" + artifactPubKeysSigFileName = "artifact-key-pub.pem.sig" + revocationFileName = "revocation-list.json" + revocationSignFileName = "revocation-list.json.sig" + + keySizeLimit = 5 * 1024 * 1024 //5MB + signatureLimit = 1024 + revocationLimit = 10 * 1024 * 1024 +) + +type ArtifactVerify struct { + rootKeys []PublicKey + keysBaseURL *url.URL + + revocationList *RevocationList +} + +func NewArtifactVerify(keysBaseURL string) (*ArtifactVerify, error) { + allKeys, err := loadEmbeddedPublicKeys() + if err != nil { + return nil, err + } + + return newArtifactVerify(keysBaseURL, allKeys) +} + +func newArtifactVerify(keysBaseURL string, allKeys []PublicKey) (*ArtifactVerify, error) { + ku, err := url.Parse(keysBaseURL) + if err != nil { + return nil, fmt.Errorf("invalid keys base URL %q: %v", keysBaseURL, err) + } + + a := &ArtifactVerify{ + rootKeys: allKeys, + keysBaseURL: ku, + } + return a, nil +} + +func (a *ArtifactVerify) Verify(ctx context.Context, version string, artifactFile string) error { + version = strings.TrimPrefix(version, "v") + + revocationList, err := a.loadRevocationList(ctx) + if err != nil { + return fmt.Errorf("failed to load revocation list: %v", err) + } + a.revocationList = revocationList + + artifactPubKeys, err := a.loadArtifactKeys(ctx) + if err != nil { + return fmt.Errorf("failed to load artifact keys: %v", err) + } + + signature, err := a.loadArtifactSignature(ctx, version, artifactFile) + if err != nil { + return fmt.Errorf("failed to download signature file for: %s, %v", filepath.Base(artifactFile), err) + } + + artifactData, err := os.ReadFile(artifactFile) + if err != nil { + log.Errorf("failed to read artifact file: %v", err) + return fmt.Errorf("failed to read artifact file: %w", err) + } + + if err := ValidateArtifact(artifactPubKeys, artifactData, *signature); err != nil { + return fmt.Errorf("failed to validate artifact: %v", err) + } + + return nil +} + +func (a *ArtifactVerify) loadRevocationList(ctx context.Context) (*RevocationList, error) { + downloadURL := a.keysBaseURL.JoinPath("keys", revocationFileName).String() + data, err := downloader.DownloadToMemory(ctx, downloadURL, revocationLimit) + if err != nil { + log.Debugf("failed to download revocation list '%s': %s", downloadURL, err) + return nil, err + } + + downloadURL = a.keysBaseURL.JoinPath("keys", revocationSignFileName).String() + sigData, err := downloader.DownloadToMemory(ctx, downloadURL, signatureLimit) + if err != nil { + log.Debugf("failed to download revocation list '%s': %s", downloadURL, err) + return nil, err + } + + signature, err := ParseSignature(sigData) + if err != nil { + log.Debugf("failed to parse revocation list signature: %s", err) + return nil, err + } + + return ValidateRevocationList(a.rootKeys, data, *signature) +} + +func (a *ArtifactVerify) loadArtifactKeys(ctx context.Context) ([]PublicKey, error) { + downloadURL := a.keysBaseURL.JoinPath("keys", artifactPubKeysFileName).String() + log.Debugf("starting downloading artifact keys from: %s", downloadURL) + data, err := downloader.DownloadToMemory(ctx, downloadURL, keySizeLimit) + if err != nil { + log.Debugf("failed to download artifact keys: %s", err) + return nil, err + } + + downloadURL = a.keysBaseURL.JoinPath("keys", artifactPubKeysSigFileName).String() + log.Debugf("start downloading signature of artifact pub key from: %s", downloadURL) + sigData, err := downloader.DownloadToMemory(ctx, downloadURL, signatureLimit) + if err != nil { + log.Debugf("failed to download signature of public keys: %s", err) + return nil, err + } + + signature, err := ParseSignature(sigData) + if err != nil { + log.Debugf("failed to parse signature of public keys: %s", err) + return nil, err + } + + return ValidateArtifactKeys(a.rootKeys, data, *signature, a.revocationList) +} + +func (a *ArtifactVerify) loadArtifactSignature(ctx context.Context, version string, artifactFile string) (*Signature, error) { + artifactFile = filepath.Base(artifactFile) + downloadURL := a.keysBaseURL.JoinPath("tag", "v"+version, artifactFile+".sig").String() + data, err := downloader.DownloadToMemory(ctx, downloadURL, signatureLimit) + if err != nil { + log.Debugf("failed to download artifact signature: %s", err) + return nil, err + } + + signature, err := ParseSignature(data) + if err != nil { + log.Debugf("failed to parse artifact signature: %s", err) + return nil, err + } + + return signature, nil + +} + +func loadEmbeddedPublicKeys() ([]PublicKey, error) { + files, err := embeddedCerts.ReadDir(embeddedCertsDir) + if err != nil { + return nil, fmt.Errorf("failed to read embedded certs: %w", err) + } + + var allKeys []PublicKey + for _, file := range files { + if file.IsDir() { + continue + } + + data, err := embeddedCerts.ReadFile(embeddedCertsDir + "/" + file.Name()) + if err != nil { + return nil, fmt.Errorf("failed to read cert file %s: %w", file.Name(), err) + } + + keys, err := parsePublicKeyBundle(data, tagRootPublic) + if err != nil { + return nil, fmt.Errorf("failed to parse cert %s: %w", file.Name(), err) + } + + allKeys = append(allKeys, keys...) + } + + if len(allKeys) == 0 { + return nil, fmt.Errorf("no valid public keys found in embedded certs") + } + + return allKeys, nil +} diff --git a/client/internal/updatemanager/reposign/verify_test.go b/client/internal/updatemanager/reposign/verify_test.go new file mode 100644 index 000000000..c29393bad --- /dev/null +++ b/client/internal/updatemanager/reposign/verify_test.go @@ -0,0 +1,528 @@ +package reposign + +import ( + "context" + "crypto/ed25519" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test ArtifactVerify construction + +func TestArtifactVerify_Construction(t *testing.T) { + // Generate test root key + rootKey, _, rootPubPEM, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + rootPubKey, _, err := parsePublicKey(rootPubPEM, tagRootPublic) + require.NoError(t, err) + + keysBaseURL := "http://localhost:8080/artifact-signatures" + + av, err := newArtifactVerify(keysBaseURL, []PublicKey{rootPubKey}) + require.NoError(t, err) + + assert.NotNil(t, av) + assert.NotEmpty(t, av.rootKeys) + assert.Equal(t, keysBaseURL, av.keysBaseURL.String()) + + // Verify root key structure + assert.NotEmpty(t, av.rootKeys[0].Key) + assert.Equal(t, rootKey.Metadata.ID, av.rootKeys[0].Metadata.ID) + assert.False(t, av.rootKeys[0].Metadata.CreatedAt.IsZero()) +} + +func TestArtifactVerify_MultipleRootKeys(t *testing.T) { + // Generate multiple test root keys + rootKey1, _, rootPubPEM1, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + rootPubKey1, _, err := parsePublicKey(rootPubPEM1, tagRootPublic) + require.NoError(t, err) + + rootKey2, _, rootPubPEM2, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + rootPubKey2, _, err := parsePublicKey(rootPubPEM2, tagRootPublic) + require.NoError(t, err) + + keysBaseURL := "http://localhost:8080/artifact-signatures" + + av, err := newArtifactVerify(keysBaseURL, []PublicKey{rootPubKey1, rootPubKey2}) + assert.NoError(t, err) + assert.Len(t, av.rootKeys, 2) + assert.NotEqual(t, rootKey1.Metadata.ID, rootKey2.Metadata.ID) +} + +// Test Verify workflow with mock HTTP server + +func TestArtifactVerify_FullWorkflow(t *testing.T) { + // Create temporary test directory + tempDir := t.TempDir() + + // Step 1: Generate root key + rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour) + require.NoError(t, err) + + // Step 2: Generate artifact key + artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM) + require.NoError(t, err) + + // Step 3: Create revocation list + revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + // Step 4: Bundle artifact keys + artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey}) + require.NoError(t, err) + + // Step 5: Create test artifact + artifactPath := filepath.Join(tempDir, "test-artifact.bin") + artifactData := []byte("This is test artifact data for verification") + err = os.WriteFile(artifactPath, artifactData, 0644) + require.NoError(t, err) + + // Step 6: Sign artifact + artifactSigData, err := SignData(*artifactKey, artifactData) + require.NoError(t, err) + + // Step 7: Setup mock HTTP server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/artifact-signatures/keys/" + revocationFileName: + _, _ = w.Write(revocationData) + case "/artifact-signatures/keys/" + revocationSignFileName: + _, _ = w.Write(revocationSig) + case "/artifact-signatures/keys/" + artifactPubKeysFileName: + _, _ = w.Write(artifactKeysBundle) + case "/artifact-signatures/keys/" + artifactPubKeysSigFileName: + _, _ = w.Write(artifactKeysSig) + case "/artifacts/v1.0.0/test-artifact.bin": + _, _ = w.Write(artifactData) + case "/artifact-signatures/tag/v1.0.0/test-artifact.bin.sig": + _, _ = w.Write(artifactSigData) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + // Step 8: Create ArtifactVerify with test root key + rootPubKey := PublicKey{ + Key: rootKey.Key.Public().(ed25519.PublicKey), + Metadata: rootKey.Metadata, + } + + av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey}) + require.NoError(t, err) + + // Step 9: Verify artifact + ctx := context.Background() + err = av.Verify(ctx, "1.0.0", artifactPath) + assert.NoError(t, err) +} + +func TestArtifactVerify_InvalidRevocationList(t *testing.T) { + tempDir := t.TempDir() + artifactPath := filepath.Join(tempDir, "test.bin") + err := os.WriteFile(artifactPath, []byte("test"), 0644) + require.NoError(t, err) + + // Setup server with invalid revocation list + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/artifact-signatures/keys/" + revocationFileName: + _, _ = w.Write([]byte("invalid data")) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + rootPubKey := PublicKey{ + Key: rootKey.Key.Public().(ed25519.PublicKey), + Metadata: rootKey.Metadata, + } + + av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey}) + require.NoError(t, err) + + ctx := context.Background() + err = av.Verify(ctx, "1.0.0", artifactPath) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to load revocation list") +} + +func TestArtifactVerify_MissingArtifactFile(t *testing.T) { + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + rootPubKey := PublicKey{ + Key: rootKey.Key.Public().(ed25519.PublicKey), + Metadata: rootKey.Metadata, + } + + // Create revocation list + revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + + artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM) + require.NoError(t, err) + + artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey}) + require.NoError(t, err) + + // Create signature for non-existent file + testData := []byte("test") + artifactSigData, err := SignData(*artifactKey, testData) + require.NoError(t, err) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/artifact-signatures/keys/" + revocationFileName: + _, _ = w.Write(revocationData) + case "/artifact-signatures/keys/" + revocationSignFileName: + _, _ = w.Write(revocationSig) + case "/artifact-signatures/keys/" + artifactPubKeysFileName: + _, _ = w.Write(artifactKeysBundle) + case "/artifact-signatures/keys/" + artifactPubKeysSigFileName: + _, _ = w.Write(artifactKeysSig) + case "/artifact-signatures/tag/v1.0.0/missing.bin.sig": + _, _ = w.Write(artifactSigData) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey}) + require.NoError(t, err) + + ctx := context.Background() + err = av.Verify(ctx, "1.0.0", "file.bin") + assert.Error(t, err) +} + +func TestArtifactVerify_ServerUnavailable(t *testing.T) { + tempDir := t.TempDir() + artifactPath := filepath.Join(tempDir, "test.bin") + err := os.WriteFile(artifactPath, []byte("test"), 0644) + require.NoError(t, err) + + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + rootPubKey := PublicKey{ + Key: rootKey.Key.Public().(ed25519.PublicKey), + Metadata: rootKey.Metadata, + } + + // Use URL that doesn't exist + av, err := newArtifactVerify("http://localhost:19999/keys", []PublicKey{rootPubKey}) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + err = av.Verify(ctx, "1.0.0", artifactPath) + assert.Error(t, err) +} + +func TestArtifactVerify_ContextCancellation(t *testing.T) { + tempDir := t.TempDir() + artifactPath := filepath.Join(tempDir, "test.bin") + err := os.WriteFile(artifactPath, []byte("test"), 0644) + require.NoError(t, err) + + // Create a server that delays response + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(500 * time.Millisecond) + _, _ = w.Write([]byte("data")) + })) + defer server.Close() + + rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour) + require.NoError(t, err) + + rootPubKey := PublicKey{ + Key: rootKey.Key.Public().(ed25519.PublicKey), + Metadata: rootKey.Metadata, + } + + av, err := newArtifactVerify(server.URL, []PublicKey{rootPubKey}) + require.NoError(t, err) + + // Create context that cancels quickly + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + err = av.Verify(ctx, "1.0.0", artifactPath) + assert.Error(t, err) +} + +func TestArtifactVerify_WithRevocation(t *testing.T) { + tempDir := t.TempDir() + + // Generate root key + rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour) + require.NoError(t, err) + + // Generate two artifact keys + artifactKey1, _, artifactPubPEM1, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + artifactPubKey1, err := ParseArtifactPubKey(artifactPubPEM1) + require.NoError(t, err) + + _, _, artifactPubPEM2, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + artifactPubKey2, err := ParseArtifactPubKey(artifactPubPEM2) + require.NoError(t, err) + + // Create revocation list with first key revoked + emptyRevocation, _, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + parsedRevocation, err := ParseRevocationList(emptyRevocation) + require.NoError(t, err) + + revocationData, revocationSig, err := ExtendRevocationList(*rootKey, *parsedRevocation, artifactPubKey1.Metadata.ID, defaultRevocationListExpiration) + require.NoError(t, err) + + // Bundle both keys + artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey1, artifactPubKey2}) + require.NoError(t, err) + + // Create artifact signed by revoked key + artifactPath := filepath.Join(tempDir, "test.bin") + artifactData := []byte("test data") + err = os.WriteFile(artifactPath, artifactData, 0644) + require.NoError(t, err) + + artifactSigData, err := SignData(*artifactKey1, artifactData) + require.NoError(t, err) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/artifact-signatures/keys/" + revocationFileName: + _, _ = w.Write(revocationData) + case "/artifact-signatures/keys/" + revocationSignFileName: + _, _ = w.Write(revocationSig) + case "/artifact-signatures/keys/" + artifactPubKeysFileName: + _, _ = w.Write(artifactKeysBundle) + case "/artifact-signatures/keys/" + artifactPubKeysSigFileName: + _, _ = w.Write(artifactKeysSig) + case "/artifact-signatures/tag/v1.0.0/test.bin.sig": + _, _ = w.Write(artifactSigData) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + rootPubKey := PublicKey{ + Key: rootKey.Key.Public().(ed25519.PublicKey), + Metadata: rootKey.Metadata, + } + + av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey}) + require.NoError(t, err) + + ctx := context.Background() + err = av.Verify(ctx, "1.0.0", artifactPath) + // Should fail because the signing key is revoked + assert.Error(t, err) + assert.Contains(t, err.Error(), "no signing Key found") +} + +func TestArtifactVerify_ValidWithSecondKey(t *testing.T) { + tempDir := t.TempDir() + + // Generate root key + rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour) + require.NoError(t, err) + + // Generate two artifact keys + _, _, artifactPubPEM1, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + artifactPubKey1, err := ParseArtifactPubKey(artifactPubPEM1) + require.NoError(t, err) + + artifactKey2, _, artifactPubPEM2, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + artifactPubKey2, err := ParseArtifactPubKey(artifactPubPEM2) + require.NoError(t, err) + + // Create revocation list with first key revoked + emptyRevocation, _, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + parsedRevocation, err := ParseRevocationList(emptyRevocation) + require.NoError(t, err) + + revocationData, revocationSig, err := ExtendRevocationList(*rootKey, *parsedRevocation, artifactPubKey1.Metadata.ID, defaultRevocationListExpiration) + require.NoError(t, err) + + // Bundle both keys + artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey1, artifactPubKey2}) + require.NoError(t, err) + + // Create artifact signed by second key (not revoked) + artifactPath := filepath.Join(tempDir, "test.bin") + artifactData := []byte("test data") + err = os.WriteFile(artifactPath, artifactData, 0644) + require.NoError(t, err) + + artifactSigData, err := SignData(*artifactKey2, artifactData) + require.NoError(t, err) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/artifact-signatures/keys/" + revocationFileName: + _, _ = w.Write(revocationData) + case "/artifact-signatures/keys/" + revocationSignFileName: + _, _ = w.Write(revocationSig) + case "/artifact-signatures/keys/" + artifactPubKeysFileName: + _, _ = w.Write(artifactKeysBundle) + case "/artifact-signatures/keys/" + artifactPubKeysSigFileName: + _, _ = w.Write(artifactKeysSig) + case "/artifact-signatures/tag/v1.0.0/test.bin.sig": + _, _ = w.Write(artifactSigData) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + rootPubKey := PublicKey{ + Key: rootKey.Key.Public().(ed25519.PublicKey), + Metadata: rootKey.Metadata, + } + + av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey}) + require.NoError(t, err) + + ctx := context.Background() + err = av.Verify(ctx, "1.0.0", artifactPath) + // Should succeed because second key is not revoked + assert.NoError(t, err) +} + +func TestArtifactVerify_TamperedArtifact(t *testing.T) { + tempDir := t.TempDir() + + // Generate root key and artifact key + rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour) + require.NoError(t, err) + + artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour) + require.NoError(t, err) + artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM) + require.NoError(t, err) + + // Create revocation list + revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration) + require.NoError(t, err) + + // Bundle keys + artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey}) + require.NoError(t, err) + + // Sign original data + originalData := []byte("original data") + artifactSigData, err := SignData(*artifactKey, originalData) + require.NoError(t, err) + + // Write tampered data to file + artifactPath := filepath.Join(tempDir, "test.bin") + tamperedData := []byte("tampered data") + err = os.WriteFile(artifactPath, tamperedData, 0644) + require.NoError(t, err) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/artifact-signatures/keys/" + revocationFileName: + _, _ = w.Write(revocationData) + case "/artifact-signatures/keys/" + revocationSignFileName: + _, _ = w.Write(revocationSig) + case "/artifact-signatures/keys/" + artifactPubKeysFileName: + _, _ = w.Write(artifactKeysBundle) + case "/artifact-signatures/keys/" + artifactPubKeysSigFileName: + _, _ = w.Write(artifactKeysSig) + case "/artifact-signatures/tag/v1.0.0/test.bin.sig": + _, _ = w.Write(artifactSigData) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + rootPubKey := PublicKey{ + Key: rootKey.Key.Public().(ed25519.PublicKey), + Metadata: rootKey.Metadata, + } + + av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey}) + require.NoError(t, err) + + ctx := context.Background() + err = av.Verify(ctx, "1.0.0", artifactPath) + // Should fail because artifact was tampered + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to validate artifact") +} + +// Test URL validation + +func TestArtifactVerify_URLParsing(t *testing.T) { + tests := []struct { + name string + keysBaseURL string + expectError bool + }{ + { + name: "Valid HTTP URL", + keysBaseURL: "http://example.com/artifact-signatures", + expectError: false, + }, + { + name: "Valid HTTPS URL", + keysBaseURL: "https://example.com/artifact-signatures", + expectError: false, + }, + { + name: "URL with port", + keysBaseURL: "http://localhost:8080/artifact-signatures", + expectError: false, + }, + { + name: "Invalid URL", + keysBaseURL: "://invalid", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := newArtifactVerify(tt.keysBaseURL, nil) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/client/internal/updatemanager/update.go b/client/internal/updatemanager/update.go new file mode 100644 index 000000000..875b50b49 --- /dev/null +++ b/client/internal/updatemanager/update.go @@ -0,0 +1,11 @@ +package updatemanager + +import v "github.com/hashicorp/go-version" + +type UpdateInterface interface { + StopWatch() + SetDaemonVersion(newVersion string) bool + SetOnUpdateListener(updateFn func()) + LatestVersion() *v.Version + StartFetcher() +} diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 463c93d57..e901386d9 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -75,6 +75,8 @@ type Client struct { dnsManager dns.IosDnsManager loginComplete bool connectClient *internal.ConnectClient + // preloadedConfig holds config loaded from JSON (used on tvOS where file writes are blocked) + preloadedConfig *profilemanager.Config } // NewClient instantiate a new Client @@ -92,17 +94,44 @@ func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName s } } +// SetConfigFromJSON loads config from a JSON string into memory. +// This is used on tvOS where file writes to App Group containers are blocked. +// When set, IsLoginRequired() and Run() will use this preloaded config instead of reading from file. +func (c *Client) SetConfigFromJSON(jsonStr string) error { + cfg, err := profilemanager.ConfigFromJSON(jsonStr) + if err != nil { + log.Errorf("SetConfigFromJSON: failed to parse config JSON: %v", err) + return err + } + c.preloadedConfig = cfg + log.Infof("SetConfigFromJSON: config loaded successfully from JSON") + return nil +} + // Run start the internal client. It is a blocker function func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error { exportEnvList(envList) log.Infof("Starting NetBird client") log.Debugf("Tunnel uses interface: %s", interfaceName) - cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ - ConfigPath: c.cfgFile, - StateFilePath: c.stateFile, - }) - if err != nil { - return err + + var cfg *profilemanager.Config + var err error + + // Use preloaded config if available (tvOS where file writes are blocked) + if c.preloadedConfig != nil { + log.Infof("Run: using preloaded config from memory") + cfg = c.preloadedConfig + } else { + log.Infof("Run: loading config from file") + // Use DirectUpdateOrCreateConfig to avoid atomic file operations (temp file + rename) + // which are blocked by the tvOS sandbox in App Group containers + cfg, err = profilemanager.DirectUpdateOrCreateConfig(profilemanager.ConfigInput{ + ConfigPath: c.cfgFile, + StateFilePath: c.stateFile, + }) + if err != nil { + return err + } } c.recorder.UpdateManagementAddress(cfg.ManagementURL.String()) c.recorder.UpdateRosenpass(cfg.RosenpassEnabled, cfg.RosenpassPermissive) @@ -120,7 +149,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error { c.ctxCancelLock.Unlock() auth := NewAuthWithConfig(ctx, cfg) - err = auth.Login() + err = auth.LoginSync() if err != nil { return err } @@ -131,7 +160,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error { c.onHostDnsFn = func([]string) {} cfg.WgIface = interfaceName - c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) + c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false) return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile) } @@ -208,14 +237,45 @@ func (c *Client) IsLoginRequired() bool { defer c.ctxCancelLock.Unlock() ctx, c.ctxCancel = context.WithCancel(ctxWithValues) - cfg, _ := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ - ConfigPath: c.cfgFile, - }) + var cfg *profilemanager.Config + var err error - needsLogin, _ := internal.IsLoginRequired(ctx, cfg) + // Use preloaded config if available (tvOS where file writes are blocked) + if c.preloadedConfig != nil { + log.Infof("IsLoginRequired: using preloaded config from memory") + cfg = c.preloadedConfig + } else { + log.Infof("IsLoginRequired: loading config from file") + // Use DirectUpdateOrCreateConfig to avoid atomic file operations (temp file + rename) + // which are blocked by the tvOS sandbox in App Group containers + cfg, err = profilemanager.DirectUpdateOrCreateConfig(profilemanager.ConfigInput{ + ConfigPath: c.cfgFile, + }) + if err != nil { + log.Errorf("IsLoginRequired: failed to load config: %v", err) + // If we can't load config, assume login is required + return true + } + } + + if cfg == nil { + log.Errorf("IsLoginRequired: config is nil") + return true + } + + needsLogin, err := internal.IsLoginRequired(ctx, cfg) + if err != nil { + log.Errorf("IsLoginRequired: check failed: %v", err) + // If the check fails, assume login is required to be safe + return true + } + log.Infof("IsLoginRequired: needsLogin=%v", needsLogin) return needsLogin } +// loginForMobileAuthTimeout is the timeout for requesting auth info from the server +const loginForMobileAuthTimeout = 30 * time.Second + func (c *Client) LoginForMobile() string { var ctx context.Context //nolint @@ -228,16 +288,26 @@ func (c *Client) LoginForMobile() string { defer c.ctxCancelLock.Unlock() ctx, c.ctxCancel = context.WithCancel(ctxWithValues) - cfg, _ := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ + // Use DirectUpdateOrCreateConfig to avoid atomic file operations (temp file + rename) + // which are blocked by the tvOS sandbox in App Group containers + cfg, err := profilemanager.DirectUpdateOrCreateConfig(profilemanager.ConfigInput{ ConfigPath: c.cfgFile, }) + if err != nil { + log.Errorf("LoginForMobile: failed to load config: %v", err) + return fmt.Sprintf("failed to load config: %v", err) + } oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, false, "") if err != nil { return err.Error() } - flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO()) + // Use a bounded timeout for the auth info request to prevent indefinite hangs + authInfoCtx, authInfoCancel := context.WithTimeout(ctx, loginForMobileAuthTimeout) + defer authInfoCancel() + + flowInfo, err := oAuthFlow.RequestAuthInfo(authInfoCtx) if err != nil { return err.Error() } @@ -249,10 +319,14 @@ func (c *Client) LoginForMobile() string { defer cancel() tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo) if err != nil { + log.Errorf("LoginForMobile: WaitToken failed: %v", err) return } jwtToken := tokenInfo.GetTokenToUse() - _ = internal.Login(ctx, cfg, "", jwtToken) + if err := internal.Login(ctx, cfg, "", jwtToken); err != nil { + log.Errorf("LoginForMobile: Login failed: %v", err) + return + } c.loginComplete = true }() diff --git a/client/ios/NetBirdSDK/login.go b/client/ios/NetBirdSDK/login.go index 1c2b38a61..27fdcf5ef 100644 --- a/client/ios/NetBirdSDK/login.go +++ b/client/ios/NetBirdSDK/login.go @@ -14,6 +14,7 @@ import ( "github.com/netbirdio/netbird/client/cmd" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/system" ) @@ -33,7 +34,8 @@ type ErrListener interface { // URLOpener it is a callback interface. The Open function will be triggered if // the backend want to show an url for the user type URLOpener interface { - Open(string) + Open(url string, userCode string) + OnLoginSuccess() } // Auth can register or login new client @@ -72,13 +74,32 @@ func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth // SaveConfigIfSSOSupported test the connectivity with the management server by retrieving the server device flow info. // If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO // is not supported and returns false without saving the configuration. For other errors return false. -func (a *Auth) SaveConfigIfSSOSupported() (bool, error) { +func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) { + if listener == nil { + log.Errorf("SaveConfigIfSSOSupported: listener is nil") + return + } + go func() { + sso, err := a.saveConfigIfSSOSupported() + if err != nil { + listener.OnError(err) + } else { + listener.OnSuccess(sso) + } + }() +} + +func (a *Auth) saveConfigIfSSOSupported() (bool, error) { supportsSSO := true err := a.withBackOff(a.ctx, func() (err error) { - _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) + _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil) if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { - _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil) - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { + _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) + s, ok := gstatus.FromError(err) + if !ok { + return err + } + if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented { supportsSSO = false err = nil } @@ -97,12 +118,29 @@ func (a *Auth) SaveConfigIfSSOSupported() (bool, error) { return false, fmt.Errorf("backoff cycle failed: %v", err) } - err = profilemanager.WriteOutConfig(a.cfgPath, a.config) + // Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename) + // which are blocked by the tvOS sandbox in App Group containers + err = profilemanager.DirectWriteOutConfig(a.cfgPath, a.config) return true, err } // LoginWithSetupKeyAndSaveConfig test the connectivity with the management server with the setup key. -func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error { +func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupKey string, deviceName string) { + if resultListener == nil { + log.Errorf("LoginWithSetupKeyAndSaveConfig: resultListener is nil") + return + } + go func() { + err := a.loginWithSetupKeyAndSaveConfig(setupKey, deviceName) + if err != nil { + resultListener.OnError(err) + } else { + resultListener.OnSuccess() + } + }() +} + +func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error { //nolint ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName) @@ -118,10 +156,14 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string return fmt.Errorf("backoff cycle failed: %v", err) } - return profilemanager.WriteOutConfig(a.cfgPath, a.config) + // Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename) + // which are blocked by the tvOS sandbox in App Group containers + return profilemanager.DirectWriteOutConfig(a.cfgPath, a.config) } -func (a *Auth) Login() error { +// LoginSync performs a synchronous login check without UI interaction +// Used for background VPN connection where user should already be authenticated +func (a *Auth) LoginSync() error { var needsLogin bool // check if we need to generate JWT token @@ -135,23 +177,142 @@ func (a *Auth) Login() error { jwtToken := "" if needsLogin { - return fmt.Errorf("Not authenticated") + return fmt.Errorf("not authenticated") } err = a.withBackOff(a.ctx, func() error { err := internal.Login(a.ctx, a.config, "", jwtToken) - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { - return nil + if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) { + // PermissionDenied means registration is required or peer is blocked + return backoff.Permanent(err) } return err }) + if err != nil { + return fmt.Errorf("login failed: %v", err) + } + + return nil +} + +// Login performs interactive login with device authentication support +// Deprecated: Use LoginWithDeviceName instead to ensure proper device naming on tvOS +func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, forceDeviceAuth bool) { + // Use empty device name - system will use hostname as fallback + a.LoginWithDeviceName(resultListener, urlOpener, forceDeviceAuth, "") +} + +// LoginWithDeviceName performs interactive login with device authentication support +// The deviceName parameter allows specifying a custom device name (required for tvOS) +func (a *Auth) LoginWithDeviceName(resultListener ErrListener, urlOpener URLOpener, forceDeviceAuth bool, deviceName string) { + if resultListener == nil { + log.Errorf("LoginWithDeviceName: resultListener is nil") + return + } + if urlOpener == nil { + log.Errorf("LoginWithDeviceName: urlOpener is nil") + resultListener.OnError(fmt.Errorf("urlOpener is nil")) + return + } + go func() { + err := a.login(urlOpener, forceDeviceAuth, deviceName) + if err != nil { + resultListener.OnError(err) + } else { + resultListener.OnSuccess() + } + }() +} + +func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName string) error { + var needsLogin bool + + // Create context with device name if provided + ctx := a.ctx + if deviceName != "" { + //nolint:staticcheck + ctx = context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName) + } + + // check if we need to generate JWT token + err := a.withBackOff(ctx, func() (err error) { + needsLogin, err = internal.IsLoginRequired(ctx, a.config) + return + }) if err != nil { return fmt.Errorf("backoff cycle failed: %v", err) } + jwtToken := "" + if needsLogin { + tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, forceDeviceAuth) + if err != nil { + return fmt.Errorf("interactive sso login failed: %v", err) + } + jwtToken = tokenInfo.GetTokenToUse() + } + + err = a.withBackOff(ctx, func() error { + err := internal.Login(ctx, a.config, "", jwtToken) + if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) { + // PermissionDenied means registration is required or peer is blocked + return backoff.Permanent(err) + } + return err + }) + if err != nil { + return fmt.Errorf("login failed: %v", err) + } + + // Save the config before notifying success to ensure persistence completes + // before the callback potentially triggers teardown on the Swift side. + // Note: This differs from Android which doesn't save config after login. + // On iOS/tvOS, we save here because: + // 1. The config may have been modified during login (e.g., new tokens) + // 2. On tvOS, the Network Extension context may be the only place with + // write permissions to the App Group container + if a.cfgPath != "" { + if err := profilemanager.DirectWriteOutConfig(a.cfgPath, a.config); err != nil { + log.Warnf("failed to save config after login: %v", err) + } + } + + // Notify caller of successful login synchronously before returning + urlOpener.OnLoginSuccess() + return nil } +const authInfoRequestTimeout = 30 * time.Second + +func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) { + oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, forceDeviceAuth, "") + if err != nil { + return nil, err + } + + // Use a bounded timeout for the auth info request to prevent indefinite hangs + authInfoCtx, authInfoCancel := context.WithTimeout(a.ctx, authInfoRequestTimeout) + defer authInfoCancel() + + flowInfo, err := oAuthFlow.RequestAuthInfo(authInfoCtx) + if err != nil { + return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err) + } + + urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode) + + waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second + waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout) + defer cancel() + tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo) + if err != nil { + return nil, fmt.Errorf("waiting for browser login failed: %v", err) + } + + return &tokenInfo, nil +} + func (a *Auth) withBackOff(ctx context.Context, bf func() error) error { return backoff.RetryNotify( bf, @@ -160,3 +321,24 @@ func (a *Auth) withBackOff(ctx context.Context, bf func() error) error { log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err) }) } + +// GetConfigJSON returns the current config as a JSON string. +// This can be used by the caller to persist the config via alternative storage +// mechanisms (e.g., UserDefaults on tvOS where file writes are blocked). +func (a *Auth) GetConfigJSON() (string, error) { + if a.config == nil { + return "", fmt.Errorf("no config available") + } + return profilemanager.ConfigToJSON(a.config) +} + +// SetConfigFromJSON loads config from a JSON string. +// This can be used to restore config from alternative storage mechanisms. +func (a *Auth) SetConfigFromJSON(jsonStr string) error { + cfg, err := profilemanager.ConfigFromJSON(jsonStr) + if err != nil { + return err + } + a.config = cfg + return nil +} diff --git a/client/ios/NetBirdSDK/preferences.go b/client/ios/NetBirdSDK/preferences.go index 39ae06538..c26a6decd 100644 --- a/client/ios/NetBirdSDK/preferences.go +++ b/client/ios/NetBirdSDK/preferences.go @@ -112,6 +112,8 @@ func (p *Preferences) GetRosenpassPermissive() (bool, error) { // Commit write out the changes into config file func (p *Preferences) Commit() error { - _, err := profilemanager.UpdateOrCreateConfig(p.configInput) + // Use DirectUpdateOrCreateConfig to avoid atomic file operations (temp file + rename) + // which are blocked by the tvOS sandbox in App Group containers + _, err := profilemanager.DirectUpdateOrCreateConfig(p.configInput) return err } diff --git a/client/netbird.wxs b/client/netbird.wxs index ba827debf..03221dd91 100644 --- a/client/netbird.wxs +++ b/client/netbird.wxs @@ -51,7 +51,7 @@ - + diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 28e8b2d4e..5d56befc7 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.6 -// protoc v6.32.1 +// protoc v3.21.12 // source: daemon.proto package proto @@ -893,6 +893,7 @@ type UpRequest struct { state protoimpl.MessageState `protogen:"open.v1"` ProfileName *string `protobuf:"bytes,1,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"` Username *string `protobuf:"bytes,2,opt,name=username,proto3,oneof" json:"username,omitempty"` + AutoUpdate *bool `protobuf:"varint,3,opt,name=autoUpdate,proto3,oneof" json:"autoUpdate,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -941,6 +942,13 @@ func (x *UpRequest) GetUsername() string { return "" } +func (x *UpRequest) GetAutoUpdate() bool { + if x != nil && x.AutoUpdate != nil { + return *x.AutoUpdate + } + return false +} + type UpResponse struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -2005,6 +2013,7 @@ type SSHSessionInfo struct { RemoteAddress string `protobuf:"bytes,2,opt,name=remoteAddress,proto3" json:"remoteAddress,omitempty"` Command string `protobuf:"bytes,3,opt,name=command,proto3" json:"command,omitempty"` JwtUsername string `protobuf:"bytes,4,opt,name=jwtUsername,proto3" json:"jwtUsername,omitempty"` + PortForwards []string `protobuf:"bytes,5,rep,name=portForwards,proto3" json:"portForwards,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -2067,6 +2076,13 @@ func (x *SSHSessionInfo) GetJwtUsername() string { return "" } +func (x *SSHSessionInfo) GetPortForwards() []string { + if x != nil { + return x.PortForwards + } + return nil +} + // SSHServerState contains the latest state of the SSH server type SSHServerState struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -5356,6 +5372,94 @@ func (x *WaitJWTTokenResponse) GetExpiresIn() int64 { return 0 } +type InstallerResultRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *InstallerResultRequest) Reset() { + *x = InstallerResultRequest{} + mi := &file_daemon_proto_msgTypes[79] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *InstallerResultRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*InstallerResultRequest) ProtoMessage() {} + +func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[79] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use InstallerResultRequest.ProtoReflect.Descriptor instead. +func (*InstallerResultRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{79} +} + +type InstallerResultResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` + ErrorMsg string `protobuf:"bytes,2,opt,name=errorMsg,proto3" json:"errorMsg,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *InstallerResultResponse) Reset() { + *x = InstallerResultResponse{} + mi := &file_daemon_proto_msgTypes[80] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *InstallerResultResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*InstallerResultResponse) ProtoMessage() {} + +func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[80] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use InstallerResultResponse.ProtoReflect.Descriptor instead. +func (*InstallerResultResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{80} +} + +func (x *InstallerResultResponse) GetSuccess() bool { + if x != nil { + return x.Success + } + return false +} + +func (x *InstallerResultResponse) GetErrorMsg() string { + if x != nil { + return x.ErrorMsg + } + return "" +} + type PortInfo_Range struct { state protoimpl.MessageState `protogen:"open.v1"` Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"` @@ -5366,7 +5470,7 @@ type PortInfo_Range struct { func (x *PortInfo_Range) Reset() { *x = PortInfo_Range{} - mi := &file_daemon_proto_msgTypes[80] + mi := &file_daemon_proto_msgTypes[82] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5378,7 +5482,7 @@ func (x *PortInfo_Range) String() string { func (*PortInfo_Range) ProtoMessage() {} func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[80] + mi := &file_daemon_proto_msgTypes[82] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5502,12 +5606,16 @@ const file_daemon_proto_rawDesc = "" + "\buserCode\x18\x01 \x01(\tR\buserCode\x12\x1a\n" + "\bhostname\x18\x02 \x01(\tR\bhostname\",\n" + "\x14WaitSSOLoginResponse\x12\x14\n" + - "\x05email\x18\x01 \x01(\tR\x05email\"p\n" + + "\x05email\x18\x01 \x01(\tR\x05email\"\xa4\x01\n" + "\tUpRequest\x12%\n" + "\vprofileName\x18\x01 \x01(\tH\x00R\vprofileName\x88\x01\x01\x12\x1f\n" + - "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" + + "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01\x12#\n" + + "\n" + + "autoUpdate\x18\x03 \x01(\bH\x02R\n" + + "autoUpdate\x88\x01\x01B\x0e\n" + "\f_profileNameB\v\n" + - "\t_username\"\f\n" + + "\t_usernameB\r\n" + + "\v_autoUpdate\"\f\n" + "\n" + "UpResponse\"\xa1\x01\n" + "\rStatusRequest\x12,\n" + @@ -5606,12 +5714,13 @@ const file_daemon_proto_rawDesc = "" + "\aservers\x18\x01 \x03(\tR\aservers\x12\x18\n" + "\adomains\x18\x02 \x03(\tR\adomains\x12\x18\n" + "\aenabled\x18\x03 \x01(\bR\aenabled\x12\x14\n" + - "\x05error\x18\x04 \x01(\tR\x05error\"\x8e\x01\n" + + "\x05error\x18\x04 \x01(\tR\x05error\"\xb2\x01\n" + "\x0eSSHSessionInfo\x12\x1a\n" + "\busername\x18\x01 \x01(\tR\busername\x12$\n" + "\rremoteAddress\x18\x02 \x01(\tR\rremoteAddress\x12\x18\n" + "\acommand\x18\x03 \x01(\tR\acommand\x12 \n" + - "\vjwtUsername\x18\x04 \x01(\tR\vjwtUsername\"^\n" + + "\vjwtUsername\x18\x04 \x01(\tR\vjwtUsername\x12\"\n" + + "\fportForwards\x18\x05 \x03(\tR\fportForwards\"^\n" + "\x0eSSHServerState\x12\x18\n" + "\aenabled\x18\x01 \x01(\bR\aenabled\x122\n" + "\bsessions\x18\x02 \x03(\v2\x16.daemon.SSHSessionInfoR\bsessions\"\xaf\x04\n" + @@ -5893,7 +6002,11 @@ const file_daemon_proto_rawDesc = "" + "\x14WaitJWTTokenResponse\x12\x14\n" + "\x05token\x18\x01 \x01(\tR\x05token\x12\x1c\n" + "\ttokenType\x18\x02 \x01(\tR\ttokenType\x12\x1c\n" + - "\texpiresIn\x18\x03 \x01(\x03R\texpiresIn*b\n" + + "\texpiresIn\x18\x03 \x01(\x03R\texpiresIn\"\x18\n" + + "\x16InstallerResultRequest\"O\n" + + "\x17InstallerResultResponse\x12\x18\n" + + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" + + "\berrorMsg\x18\x02 \x01(\tR\berrorMsg*b\n" + "\bLogLevel\x12\v\n" + "\aUNKNOWN\x10\x00\x12\t\n" + "\x05PANIC\x10\x01\x12\t\n" + @@ -5902,7 +6015,7 @@ const file_daemon_proto_rawDesc = "" + "\x04WARN\x10\x04\x12\b\n" + "\x04INFO\x10\x05\x12\t\n" + "\x05DEBUG\x10\x06\x12\t\n" + - "\x05TRACE\x10\a2\xdb\x12\n" + + "\x05TRACE\x10\a2\xb4\x13\n" + "\rDaemonService\x126\n" + "\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" + "\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" + @@ -5938,7 +6051,8 @@ const file_daemon_proto_rawDesc = "" + "\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" + "\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" + "\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12N\n" + - "\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00B\bZ\x06/protob\x06proto3" + "\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" + + "\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00B\bZ\x06/protob\x06proto3" var ( file_daemon_proto_rawDescOnce sync.Once @@ -5953,7 +6067,7 @@ func file_daemon_proto_rawDescGZIP() []byte { } var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4) -var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 82) +var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 84) var file_daemon_proto_goTypes = []any{ (LogLevel)(0), // 0: daemon.LogLevel (OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType @@ -6038,19 +6152,21 @@ var file_daemon_proto_goTypes = []any{ (*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse (*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest (*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse - nil, // 83: daemon.Network.ResolvedIPsEntry - (*PortInfo_Range)(nil), // 84: daemon.PortInfo.Range - nil, // 85: daemon.SystemEvent.MetadataEntry - (*durationpb.Duration)(nil), // 86: google.protobuf.Duration - (*timestamppb.Timestamp)(nil), // 87: google.protobuf.Timestamp + (*InstallerResultRequest)(nil), // 83: daemon.InstallerResultRequest + (*InstallerResultResponse)(nil), // 84: daemon.InstallerResultResponse + nil, // 85: daemon.Network.ResolvedIPsEntry + (*PortInfo_Range)(nil), // 86: daemon.PortInfo.Range + nil, // 87: daemon.SystemEvent.MetadataEntry + (*durationpb.Duration)(nil), // 88: google.protobuf.Duration + (*timestamppb.Timestamp)(nil), // 89: google.protobuf.Timestamp } var file_daemon_proto_depIdxs = []int32{ 1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType - 86, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 88, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus - 87, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp - 87, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp - 86, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration + 89, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp + 89, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp + 88, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration 25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo 22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState 21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState @@ -6061,8 +6177,8 @@ var file_daemon_proto_depIdxs = []int32{ 57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent 26, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState 33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network - 83, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry - 84, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range + 85, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry + 86, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range 34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo 34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo 35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule @@ -6073,10 +6189,10 @@ var file_daemon_proto_depIdxs = []int32{ 54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage 2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity 3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category - 87, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp - 85, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry + 89, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp + 87, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry 57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent - 86, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 88, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile 32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList 7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest @@ -6111,40 +6227,42 @@ var file_daemon_proto_depIdxs = []int32{ 79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest 81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest 5, // 65: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest - 8, // 66: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 10, // 67: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 12, // 68: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 14, // 69: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 16, // 70: daemon.DaemonService.Down:output_type -> daemon.DownResponse - 18, // 71: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 29, // 72: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse - 31, // 73: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse - 31, // 74: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse - 36, // 75: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse - 38, // 76: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse - 40, // 77: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse - 42, // 78: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse - 45, // 79: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse - 47, // 80: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse - 49, // 81: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse - 51, // 82: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse - 55, // 83: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse - 57, // 84: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent - 59, // 85: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse - 61, // 86: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse - 63, // 87: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse - 65, // 88: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse - 67, // 89: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse - 69, // 90: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse - 72, // 91: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse - 74, // 92: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse - 76, // 93: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse - 78, // 94: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse - 80, // 95: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse - 82, // 96: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse - 6, // 97: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse - 66, // [66:98] is the sub-list for method output_type - 34, // [34:66] is the sub-list for method input_type + 83, // 66: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest + 8, // 67: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 10, // 68: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse + 12, // 69: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 14, // 70: daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 16, // 71: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 18, // 72: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 29, // 73: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse + 31, // 74: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse + 31, // 75: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse + 36, // 76: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse + 38, // 77: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse + 40, // 78: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse + 42, // 79: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse + 45, // 80: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse + 47, // 81: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse + 49, // 82: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse + 51, // 83: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse + 55, // 84: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse + 57, // 85: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent + 59, // 86: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse + 61, // 87: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse + 63, // 88: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse + 65, // 89: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse + 67, // 90: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse + 69, // 91: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse + 72, // 92: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse + 74, // 93: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse + 76, // 94: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse + 78, // 95: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse + 80, // 96: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse + 82, // 97: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse + 6, // 98: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse + 84, // 99: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse + 67, // [67:100] is the sub-list for method output_type + 34, // [34:67] is the sub-list for method input_type 34, // [34:34] is the sub-list for extension type_name 34, // [34:34] is the sub-list for extension extendee 0, // [0:34] is the sub-list for field type_name @@ -6174,7 +6292,7 @@ func file_daemon_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)), NumEnums: 4, - NumMessages: 82, + NumMessages: 84, NumExtensions: 0, NumServices: 1, }, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 3dfd3da8d..b75ca821a 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -95,6 +95,8 @@ service DaemonService { rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {} rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {} + + rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {} } @@ -215,6 +217,7 @@ message WaitSSOLoginResponse { message UpRequest { optional string profileName = 1; optional string username = 2; + optional bool autoUpdate = 3; } message UpResponse {} @@ -369,6 +372,7 @@ message SSHSessionInfo { string remoteAddress = 2; string command = 3; string jwtUsername = 4; + repeated string portForwards = 5; } // SSHServerState contains the latest state of the SSH server @@ -772,3 +776,11 @@ message WaitJWTTokenResponse { // expiration time in seconds int64 expiresIn = 3; } + +message InstallerResultRequest { +} + +message InstallerResultResponse { + bool success = 1; + string errorMsg = 2; +} diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index 6b01309b7..fdabb1879 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -71,6 +71,7 @@ type DaemonServiceClient interface { // WaitJWTToken waits for JWT authentication completion WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error) NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) + GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error) } type daemonServiceClient struct { @@ -392,6 +393,15 @@ func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifec return out, nil } +func (c *daemonServiceClient) GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error) { + out := new(InstallerResultResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetInstallerResult", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // DaemonServiceServer is the server API for DaemonService service. // All implementations must embed UnimplementedDaemonServiceServer // for forward compatibility @@ -449,6 +459,7 @@ type DaemonServiceServer interface { // WaitJWTToken waits for JWT authentication completion WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) + GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) mustEmbedUnimplementedDaemonServiceServer() } @@ -552,6 +563,9 @@ func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTo func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented") } +func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetInstallerResult not implemented") +} func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. @@ -1144,6 +1158,24 @@ func _DaemonService_NotifyOSLifecycle_Handler(srv interface{}, ctx context.Conte return interceptor(ctx, in, info, handler) } +func _DaemonService_GetInstallerResult_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(InstallerResultRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).GetInstallerResult(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/GetInstallerResult", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).GetInstallerResult(ctx, req.(*InstallerResultRequest)) + } + return interceptor(ctx, in, info, handler) +} + // DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -1275,6 +1307,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ MethodName: "NotifyOSLifecycle", Handler: _DaemonService_NotifyOSLifecycle_Handler, }, + { + MethodName: "GetInstallerResult", + Handler: _DaemonService_GetInstallerResult_Handler, + }, }, Streams: []grpc.StreamDesc{ { diff --git a/client/proto/generate.sh b/client/proto/generate.sh index f9a2c3750..e659cef90 100755 --- a/client/proto/generate.sh +++ b/client/proto/generate.sh @@ -14,4 +14,4 @@ cd "$script_path" go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.36.6 go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1 protoc -I ./ ./daemon.proto --go_out=../ --go-grpc_out=../ --experimental_allow_proto3_optional -cd "$old_pwd" \ No newline at end of file +cd "$old_pwd" diff --git a/client/server/server.go b/client/server/server.go index d33595115..7b6c4e98c 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -145,10 +145,10 @@ func (s *Server) Start() error { ctx, cancel := context.WithCancel(s.rootCtx) s.actCancel = cancel - // set the default config if not exists - if err := s.setDefaultConfigIfNotExists(ctx); err != nil { - log.Errorf("failed to set default config: %v", err) - return fmt.Errorf("failed to set default config: %w", err) + // copy old default config + _, err = s.profileManager.CopyDefaultProfileIfNotExists() + if err != nil && !errors.Is(err, profilemanager.ErrorOldDefaultConfigNotFound) { + return err } activeProf, err := s.profileManager.GetActiveProfileState() @@ -156,23 +156,11 @@ func (s *Server) Start() error { return fmt.Errorf("failed to get active profile state: %w", err) } - config, err := s.getConfig(activeProf) + config, existingConfig, err := s.getConfig(activeProf) if err != nil { log.Errorf("failed to get active profile config: %v", err) - if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{ - Name: "default", - Username: "", - }); err != nil { - log.Errorf("failed to set active profile state: %v", err) - return fmt.Errorf("failed to set active profile state: %w", err) - } - - config, err = profilemanager.GetConfig(s.profileManager.DefaultProfilePath()) - if err != nil { - log.Errorf("failed to get default profile config: %v", err) - return fmt.Errorf("failed to get default profile config: %w", err) - } + return err } s.config = config @@ -186,44 +174,27 @@ func (s *Server) Start() error { } if config.DisableAutoConnect { + state.Set(internal.StatusIdle) + return nil + } + + if !existingConfig { + log.Warnf("not trying to connect when configuration was just created") + state.Set(internal.StatusNeedsLogin) return nil } s.clientRunning = true s.clientRunningChan = make(chan struct{}) s.clientGiveUpChan = make(chan struct{}) - go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan) - return nil -} - -func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error { - ok, err := s.profileManager.CopyDefaultProfileIfNotExists() - if err != nil { - if err := s.profileManager.CreateDefaultProfile(); err != nil { - log.Errorf("failed to create default profile: %v", err) - return fmt.Errorf("failed to create default profile: %w", err) - } - - if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{ - Name: "default", - Username: "", - }); err != nil { - log.Errorf("failed to set active profile state: %v", err) - return fmt.Errorf("failed to set active profile state: %w", err) - } - } - if ok { - state := internal.CtxGetState(ctx) - state.Set(internal.StatusNeedsLogin) - } - + go s.connectWithRetryRuns(ctx, config, s.statusRecorder, false, s.clientRunningChan, s.clientGiveUpChan) return nil } // connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional // mechanism to keep the client connected even when the connection is lost. // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. -func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) { +func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}, giveUpChan chan struct{}) { defer func() { s.mutex.Lock() s.clientRunning = false @@ -231,7 +202,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil }() if s.config.DisableAutoConnect { - if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil { + if err := s.connect(ctx, s.config, s.statusRecorder, doInitialAutoUpdate, runningChan); err != nil { log.Debugf("run client connection exited with error: %v", err) } log.Tracef("client connection exited") @@ -260,7 +231,8 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil }() runOperation := func() error { - err := s.connect(ctx, profileConfig, statusRecorder, runningChan) + err := s.connect(ctx, profileConfig, statusRecorder, doInitialAutoUpdate, runningChan) + doInitialAutoUpdate = false if err != nil { log.Debugf("run client connection exited with error: %v. Will retry in the background", err) return err @@ -486,7 +458,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro s.mutex.Unlock() - config, err := s.getConfig(activeProf) + config, _, err := s.getConfig(activeProf) if err != nil { log.Errorf("failed to get active profile config: %v", err) return nil, fmt.Errorf("failed to get active profile config: %w", err) @@ -715,7 +687,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR log.Infof("active profile: %s for %s", activeProf.Name, activeProf.Username) - config, err := s.getConfig(activeProf) + config, _, err := s.getConfig(activeProf) if err != nil { log.Errorf("failed to get active profile config: %v", err) return nil, fmt.Errorf("failed to get active profile config: %w", err) @@ -728,7 +700,12 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR s.clientRunning = true s.clientRunningChan = make(chan struct{}) s.clientGiveUpChan = make(chan struct{}) - go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan) + + var doAutoUpdate bool + if msg != nil && msg.AutoUpdate != nil && *msg.AutoUpdate { + doAutoUpdate = true + } + go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, doAutoUpdate, s.clientRunningChan, s.clientGiveUpChan) return s.waitForUp(callerCtx) } @@ -805,7 +782,7 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi log.Errorf("failed to get active profile state: %v", err) return nil, fmt.Errorf("failed to get active profile state: %w", err) } - config, err := s.getConfig(activeProf) + config, _, err := s.getConfig(activeProf) if err != nil { log.Errorf("failed to get default profile config: %v", err) return nil, fmt.Errorf("failed to get default profile config: %w", err) @@ -902,7 +879,7 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe return nil, gstatus.Errorf(codes.FailedPrecondition, "failed to get active profile state: %v", err) } - config, err := s.getConfig(activeProf) + config, _, err := s.getConfig(activeProf) if err != nil { return nil, gstatus.Errorf(codes.FailedPrecondition, "not logged in") } @@ -926,19 +903,24 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe return &proto.LogoutResponse{}, nil } -// getConfig loads the config from the active profile -func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*profilemanager.Config, error) { +// GetConfig reads config file and returns Config and whether the config file already existed. Errors out if it does not exist +func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*profilemanager.Config, bool, error) { cfgPath, err := activeProf.FilePath() if err != nil { - return nil, fmt.Errorf("failed to get active profile file path: %w", err) + return nil, false, fmt.Errorf("failed to get active profile file path: %w", err) } - config, err := profilemanager.GetConfig(cfgPath) + _, err = os.Stat(cfgPath) + configExisted := !os.IsNotExist(err) + + log.Infof("active profile config existed: %t, err %v", configExisted, err) + + config, err := profilemanager.ReadConfig(cfgPath) if err != nil { - return nil, fmt.Errorf("failed to get config: %w", err) + return nil, false, fmt.Errorf("failed to get config: %w", err) } - return config, nil + return config, configExisted, nil } func (s *Server) canRemoveProfile(profileName string) error { @@ -1122,6 +1104,7 @@ func (s *Server) getSSHServerState() *proto.SSHServerState { RemoteAddress: session.RemoteAddress, Command: session.Command, JwtUsername: session.JWTUsername, + PortForwards: session.PortForwards, }) } @@ -1539,9 +1522,9 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) return features, nil } -func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error { +func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}) error { log.Tracef("running client connection") - s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder) + s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, doInitialAutoUpdate) s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse) if err := s.connectClient.Run(runningChan); err != nil { return err diff --git a/client/server/server_test.go b/client/server/server_test.go index 5f28a2664..1ed115769 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -112,7 +112,7 @@ func TestConnectWithRetryRuns(t *testing.T) { t.Setenv(maxRetryTimeVar, "5s") t.Setenv(retryMultiplierVar, "1") - s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil) + s.connectWithRetryRuns(ctx, config, s.statusRecorder, false, nil, nil) if counter < 3 { t.Fatalf("expected counter > 2, got %d", counter) } @@ -326,7 +326,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve if err != nil { return nil, "", err } - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil) if err != nil { return nil, "", err } diff --git a/client/server/updateresult.go b/client/server/updateresult.go new file mode 100644 index 000000000..8e00d5062 --- /dev/null +++ b/client/server/updateresult.go @@ -0,0 +1,30 @@ +package server + +import ( + "context" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/updatemanager/installer" + "github.com/netbirdio/netbird/client/proto" +) + +func (s *Server) GetInstallerResult(ctx context.Context, _ *proto.InstallerResultRequest) (*proto.InstallerResultResponse, error) { + inst := installer.New() + dir := inst.TempDir() + + rh := installer.NewResultHandler(dir) + result, err := rh.Watch(ctx) + if err != nil { + log.Errorf("failed to watch update result: %v", err) + return &proto.InstallerResultResponse{ + Success: false, + ErrorMsg: err.Error(), + }, nil + } + + return &proto.InstallerResultResponse{ + Success: result.Success, + ErrorMsg: result.Error, + }, nil +} diff --git a/client/ssh/auth/auth.go b/client/ssh/auth/auth.go new file mode 100644 index 000000000..079282fdc --- /dev/null +++ b/client/ssh/auth/auth.go @@ -0,0 +1,177 @@ +package auth + +import ( + "errors" + "fmt" + "sync" + + log "github.com/sirupsen/logrus" + + sshuserhash "github.com/netbirdio/netbird/shared/sshauth" +) + +const ( + // DefaultUserIDClaim is the default JWT claim used to extract user IDs + DefaultUserIDClaim = "sub" + // Wildcard is a special user ID that matches all users + Wildcard = "*" +) + +var ( + ErrEmptyUserID = errors.New("JWT user ID is empty") + ErrUserNotAuthorized = errors.New("user is not authorized to access this peer") + ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user") + ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user") +) + +// Authorizer handles SSH fine-grained access control authorization +type Authorizer struct { + // UserIDClaim is the JWT claim to extract the user ID from + userIDClaim string + + // authorizedUsers is a list of hashed user IDs authorized to access this peer + authorizedUsers []sshuserhash.UserIDHash + + // machineUsers maps OS login usernames to lists of authorized user indexes + machineUsers map[string][]uint32 + + // mu protects the list of users + mu sync.RWMutex +} + +// Config contains configuration for the SSH authorizer +type Config struct { + // UserIDClaim is the JWT claim to extract the user ID from (e.g., "sub", "email") + UserIDClaim string + + // AuthorizedUsers is a list of hashed user IDs (FNV-1a 64-bit) authorized to access this peer + AuthorizedUsers []sshuserhash.UserIDHash + + // MachineUsers maps OS login usernames to indexes in AuthorizedUsers + // If a user wants to login as a specific OS user, their index must be in the corresponding list + MachineUsers map[string][]uint32 +} + +// NewAuthorizer creates a new SSH authorizer with empty configuration +func NewAuthorizer() *Authorizer { + a := &Authorizer{ + userIDClaim: DefaultUserIDClaim, + machineUsers: make(map[string][]uint32), + } + + return a +} + +// Update updates the authorizer configuration with new values +func (a *Authorizer) Update(config *Config) { + a.mu.Lock() + defer a.mu.Unlock() + + if config == nil { + // Clear authorization + a.userIDClaim = DefaultUserIDClaim + a.authorizedUsers = []sshuserhash.UserIDHash{} + a.machineUsers = make(map[string][]uint32) + log.Info("SSH authorization cleared") + return + } + + userIDClaim := config.UserIDClaim + if userIDClaim == "" { + userIDClaim = DefaultUserIDClaim + } + a.userIDClaim = userIDClaim + + // Store authorized users list + a.authorizedUsers = config.AuthorizedUsers + + // Store machine users mapping + machineUsers := make(map[string][]uint32) + for osUser, indexes := range config.MachineUsers { + if len(indexes) > 0 { + machineUsers[osUser] = indexes + } + } + a.machineUsers = machineUsers + + log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings", + len(config.AuthorizedUsers), len(machineUsers)) +} + +// Authorize validates if a user is authorized to login as the specified OS user. +// Returns a success message describing how authorization was granted, or an error. +func (a *Authorizer) Authorize(jwtUserID, osUsername string) (string, error) { + if jwtUserID == "" { + return "", fmt.Errorf("JWT user ID is empty for OS user %q: %w", osUsername, ErrEmptyUserID) + } + + // Hash the JWT user ID for comparison + hashedUserID, err := sshuserhash.HashUserID(jwtUserID) + if err != nil { + return "", fmt.Errorf("hash user ID %q for OS user %q: %w", jwtUserID, osUsername, err) + } + + a.mu.RLock() + defer a.mu.RUnlock() + + // Find the index of this user in the authorized list + userIndex, found := a.findUserIndex(hashedUserID) + if !found { + return "", fmt.Errorf("user %q (hash: %s) not in authorized list for OS user %q: %w", jwtUserID, hashedUserID, osUsername, ErrUserNotAuthorized) + } + + return a.checkMachineUserMapping(jwtUserID, osUsername, userIndex) +} + +// checkMachineUserMapping validates if a user's index is authorized for the specified OS user +// Checks wildcard mapping first, then specific OS user mappings +func (a *Authorizer) checkMachineUserMapping(jwtUserID, osUsername string, userIndex int) (string, error) { + // If wildcard exists and user's index is in the wildcard list, allow access to any OS user + if wildcardIndexes, hasWildcard := a.machineUsers[Wildcard]; hasWildcard { + if a.isIndexInList(uint32(userIndex), wildcardIndexes) { + return fmt.Sprintf("granted via wildcard (index: %d)", userIndex), nil + } + } + + // Check for specific OS username mapping + allowedIndexes, hasMachineUserMapping := a.machineUsers[osUsername] + if !hasMachineUserMapping { + // No mapping for this OS user - deny by default (fail closed) + return "", fmt.Errorf("no machine user mapping for OS user %q (JWT user: %s): %w", osUsername, jwtUserID, ErrNoMachineUserMapping) + } + + // Check if user's index is in the allowed indexes for this specific OS user + if !a.isIndexInList(uint32(userIndex), allowedIndexes) { + return "", fmt.Errorf("user %q not mapped to OS user %q (index: %d): %w", jwtUserID, osUsername, userIndex, ErrUserNotMappedToOSUser) + } + + return fmt.Sprintf("granted (index: %d)", userIndex), nil +} + +// GetUserIDClaim returns the JWT claim name used to extract user IDs +func (a *Authorizer) GetUserIDClaim() string { + a.mu.RLock() + defer a.mu.RUnlock() + return a.userIDClaim +} + +// findUserIndex finds the index of a hashed user ID in the authorized users list +// Returns the index and true if found, 0 and false if not found +func (a *Authorizer) findUserIndex(hashedUserID sshuserhash.UserIDHash) (int, bool) { + for i, id := range a.authorizedUsers { + if id == hashedUserID { + return i, true + } + } + return 0, false +} + +// isIndexInList checks if an index exists in a list of indexes +func (a *Authorizer) isIndexInList(index uint32, indexes []uint32) bool { + for _, idx := range indexes { + if idx == index { + return true + } + } + return false +} diff --git a/client/ssh/auth/auth_test.go b/client/ssh/auth/auth_test.go new file mode 100644 index 000000000..fa27b72e8 --- /dev/null +++ b/client/ssh/auth/auth_test.go @@ -0,0 +1,612 @@ +package auth + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/shared/sshauth" +) + +func TestAuthorizer_Authorize_UserNotInList(t *testing.T) { + authorizer := NewAuthorizer() + + // Set up authorized users list with one user + authorizedUserHash, err := sshauth.HashUserID("authorized-user") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{authorizedUserHash}, + MachineUsers: map[string][]uint32{}, + } + authorizer.Update(config) + + // Try to authorize a different user + _, err = authorizer.Authorize("unauthorized-user", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotAuthorized) +} + +func TestAuthorizer_Authorize_UserInList_NoMachineUserRestrictions(t *testing.T) { + authorizer := NewAuthorizer() + + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + user2Hash, err := sshauth.HashUserID("user2") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash}, + MachineUsers: map[string][]uint32{}, // Empty = deny all (fail closed) + } + authorizer.Update(config) + + // All attempts should fail when no machine user mappings exist (fail closed) + _, err = authorizer.Authorize("user1", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoMachineUserMapping) + + _, err = authorizer.Authorize("user2", "admin") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoMachineUserMapping) + + _, err = authorizer.Authorize("user1", "postgres") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoMachineUserMapping) +} + +func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Allowed(t *testing.T) { + authorizer := NewAuthorizer() + + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + user2Hash, err := sshauth.HashUserID("user2") + require.NoError(t, err) + user3Hash, err := sshauth.HashUserID("user3") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash}, + MachineUsers: map[string][]uint32{ + "root": {0, 1}, // user1 and user2 can access root + "postgres": {1, 2}, // user2 and user3 can access postgres + "admin": {0}, // only user1 can access admin + }, + } + authorizer.Update(config) + + // user1 (index 0) should access root and admin + _, err = authorizer.Authorize("user1", "root") + assert.NoError(t, err) + + _, err = authorizer.Authorize("user1", "admin") + assert.NoError(t, err) + + // user2 (index 1) should access root and postgres + _, err = authorizer.Authorize("user2", "root") + assert.NoError(t, err) + + _, err = authorizer.Authorize("user2", "postgres") + assert.NoError(t, err) + + // user3 (index 2) should access postgres + _, err = authorizer.Authorize("user3", "postgres") + assert.NoError(t, err) +} + +func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Denied(t *testing.T) { + authorizer := NewAuthorizer() + + // Set up authorized users list + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + user2Hash, err := sshauth.HashUserID("user2") + require.NoError(t, err) + user3Hash, err := sshauth.HashUserID("user3") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash}, + MachineUsers: map[string][]uint32{ + "root": {0, 1}, // user1 and user2 can access root + "postgres": {1, 2}, // user2 and user3 can access postgres + "admin": {0}, // only user1 can access admin + }, + } + authorizer.Update(config) + + // user1 (index 0) should NOT access postgres + _, err = authorizer.Authorize("user1", "postgres") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) + + // user2 (index 1) should NOT access admin + _, err = authorizer.Authorize("user2", "admin") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) + + // user3 (index 2) should NOT access root + _, err = authorizer.Authorize("user3", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) + + // user3 (index 2) should NOT access admin + _, err = authorizer.Authorize("user3", "admin") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) +} + +func TestAuthorizer_Authorize_UserInList_OSUserNotInMapping(t *testing.T) { + authorizer := NewAuthorizer() + + // Set up authorized users list + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash}, + MachineUsers: map[string][]uint32{ + "root": {0}, // only root is mapped + }, + } + authorizer.Update(config) + + // user1 should NOT access an unmapped OS user (fail closed) + _, err = authorizer.Authorize("user1", "postgres") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoMachineUserMapping) +} + +func TestAuthorizer_Authorize_EmptyJWTUserID(t *testing.T) { + authorizer := NewAuthorizer() + + // Set up authorized users list + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash}, + MachineUsers: map[string][]uint32{}, + } + authorizer.Update(config) + + // Empty user ID should fail + _, err = authorizer.Authorize("", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrEmptyUserID) +} + +func TestAuthorizer_Authorize_MultipleUsersInList(t *testing.T) { + authorizer := NewAuthorizer() + + // Set up multiple authorized users + userHashes := make([]sshauth.UserIDHash, 10) + for i := 0; i < 10; i++ { + hash, err := sshauth.HashUserID("user" + string(rune('0'+i))) + require.NoError(t, err) + userHashes[i] = hash + } + + // Create machine user mapping for all users + rootIndexes := make([]uint32, 10) + for i := 0; i < 10; i++ { + rootIndexes[i] = uint32(i) + } + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: userHashes, + MachineUsers: map[string][]uint32{ + "root": rootIndexes, + }, + } + authorizer.Update(config) + + // All users should be authorized for root + for i := 0; i < 10; i++ { + _, err := authorizer.Authorize("user"+string(rune('0'+i)), "root") + assert.NoError(t, err, "user%d should be authorized", i) + } + + // User not in list should fail + _, err := authorizer.Authorize("unknown-user", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotAuthorized) +} + +func TestAuthorizer_Update_ClearsConfiguration(t *testing.T) { + authorizer := NewAuthorizer() + + // Set up initial configuration + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash}, + MachineUsers: map[string][]uint32{"root": {0}}, + } + authorizer.Update(config) + + // user1 should be authorized + _, err = authorizer.Authorize("user1", "root") + assert.NoError(t, err) + + // Clear configuration + authorizer.Update(nil) + + // user1 should no longer be authorized + _, err = authorizer.Authorize("user1", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotAuthorized) +} + +func TestAuthorizer_Update_EmptyMachineUsersListEntries(t *testing.T) { + authorizer := NewAuthorizer() + + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + + // Machine users with empty index lists should be filtered out + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash}, + MachineUsers: map[string][]uint32{ + "root": {0}, + "postgres": {}, // empty list - should be filtered out + "admin": nil, // nil list - should be filtered out + }, + } + authorizer.Update(config) + + // root should work + _, err = authorizer.Authorize("user1", "root") + assert.NoError(t, err) + + // postgres should fail (no mapping) + _, err = authorizer.Authorize("user1", "postgres") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoMachineUserMapping) + + // admin should fail (no mapping) + _, err = authorizer.Authorize("user1", "admin") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoMachineUserMapping) +} + +func TestAuthorizer_CustomUserIDClaim(t *testing.T) { + authorizer := NewAuthorizer() + + // Set up with custom user ID claim + user1Hash, err := sshauth.HashUserID("user@example.com") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: "email", + AuthorizedUsers: []sshauth.UserIDHash{user1Hash}, + MachineUsers: map[string][]uint32{ + "root": {0}, + }, + } + authorizer.Update(config) + + // Verify the custom claim is set + assert.Equal(t, "email", authorizer.GetUserIDClaim()) + + // Authorize with email as user ID + _, err = authorizer.Authorize("user@example.com", "root") + assert.NoError(t, err) +} + +func TestAuthorizer_DefaultUserIDClaim(t *testing.T) { + authorizer := NewAuthorizer() + + // Verify default claim + assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim()) + assert.Equal(t, "sub", authorizer.GetUserIDClaim()) + + // Set up with empty user ID claim (should use default) + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: "", // empty - should use default + AuthorizedUsers: []sshauth.UserIDHash{user1Hash}, + MachineUsers: map[string][]uint32{}, + } + authorizer.Update(config) + + // Should fall back to default + assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim()) +} + +func TestAuthorizer_MachineUserMapping_LargeIndexes(t *testing.T) { + authorizer := NewAuthorizer() + + // Create a large authorized users list + const numUsers = 1000 + userHashes := make([]sshauth.UserIDHash, numUsers) + for i := 0; i < numUsers; i++ { + hash, err := sshauth.HashUserID("user" + string(rune(i))) + require.NoError(t, err) + userHashes[i] = hash + } + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: userHashes, + MachineUsers: map[string][]uint32{ + "root": {0, 500, 999}, // first, middle, and last user + }, + } + authorizer.Update(config) + + // First user should have access + _, err := authorizer.Authorize("user"+string(rune(0)), "root") + assert.NoError(t, err) + + // Middle user should have access + _, err = authorizer.Authorize("user"+string(rune(500)), "root") + assert.NoError(t, err) + + // Last user should have access + _, err = authorizer.Authorize("user"+string(rune(999)), "root") + assert.NoError(t, err) + + // User not in mapping should NOT have access + _, err = authorizer.Authorize("user"+string(rune(100)), "root") + assert.Error(t, err) +} + +func TestAuthorizer_ConcurrentAuthorization(t *testing.T) { + authorizer := NewAuthorizer() + + // Set up authorized users + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + user2Hash, err := sshauth.HashUserID("user2") + require.NoError(t, err) + + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash}, + MachineUsers: map[string][]uint32{ + "root": {0, 1}, + }, + } + authorizer.Update(config) + + // Test concurrent authorization calls (should be safe to read concurrently) + const numGoroutines = 100 + errChan := make(chan error, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(idx int) { + user := "user1" + if idx%2 == 0 { + user = "user2" + } + _, err := authorizer.Authorize(user, "root") + errChan <- err + }(i) + } + + // Wait for all goroutines to complete and collect errors + for i := 0; i < numGoroutines; i++ { + err := <-errChan + assert.NoError(t, err) + } +} + +func TestAuthorizer_Wildcard_AllowsAllAuthorizedUsers(t *testing.T) { + authorizer := NewAuthorizer() + + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + user2Hash, err := sshauth.HashUserID("user2") + require.NoError(t, err) + user3Hash, err := sshauth.HashUserID("user3") + require.NoError(t, err) + + // Configure with wildcard - all authorized users can access any OS user + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash}, + MachineUsers: map[string][]uint32{ + "*": {0, 1, 2}, // wildcard with all user indexes + }, + } + authorizer.Update(config) + + // All authorized users should be able to access any OS user + _, err = authorizer.Authorize("user1", "root") + assert.NoError(t, err) + + _, err = authorizer.Authorize("user2", "postgres") + assert.NoError(t, err) + + _, err = authorizer.Authorize("user3", "admin") + assert.NoError(t, err) + + _, err = authorizer.Authorize("user1", "ubuntu") + assert.NoError(t, err) + + _, err = authorizer.Authorize("user2", "nginx") + assert.NoError(t, err) + + _, err = authorizer.Authorize("user3", "docker") + assert.NoError(t, err) +} + +func TestAuthorizer_Wildcard_UnauthorizedUserStillDenied(t *testing.T) { + authorizer := NewAuthorizer() + + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + + // Configure with wildcard + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash}, + MachineUsers: map[string][]uint32{ + "*": {0}, + }, + } + authorizer.Update(config) + + // user1 should have access + _, err = authorizer.Authorize("user1", "root") + assert.NoError(t, err) + + // Unauthorized user should still be denied even with wildcard + _, err = authorizer.Authorize("unauthorized-user", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotAuthorized) +} + +func TestAuthorizer_Wildcard_TakesPrecedenceOverSpecificMappings(t *testing.T) { + authorizer := NewAuthorizer() + + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + user2Hash, err := sshauth.HashUserID("user2") + require.NoError(t, err) + + // Configure with both wildcard and specific mappings + // Wildcard takes precedence for users in the wildcard index list + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash}, + MachineUsers: map[string][]uint32{ + "*": {0, 1}, // wildcard for both users + "root": {0}, // specific mapping that would normally restrict to user1 only + }, + } + authorizer.Update(config) + + // Both users should be able to access root via wildcard (takes precedence over specific mapping) + _, err = authorizer.Authorize("user1", "root") + assert.NoError(t, err) + + _, err = authorizer.Authorize("user2", "root") + assert.NoError(t, err) + + // Both users should be able to access any other OS user via wildcard + _, err = authorizer.Authorize("user1", "postgres") + assert.NoError(t, err) + + _, err = authorizer.Authorize("user2", "admin") + assert.NoError(t, err) +} + +func TestAuthorizer_NoWildcard_SpecificMappingsOnly(t *testing.T) { + authorizer := NewAuthorizer() + + user1Hash, err := sshauth.HashUserID("user1") + require.NoError(t, err) + user2Hash, err := sshauth.HashUserID("user2") + require.NoError(t, err) + + // Configure WITHOUT wildcard - only specific mappings + config := &Config{ + UserIDClaim: DefaultUserIDClaim, + AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash}, + MachineUsers: map[string][]uint32{ + "root": {0}, // only user1 + "postgres": {1}, // only user2 + }, + } + authorizer.Update(config) + + // user1 can access root + _, err = authorizer.Authorize("user1", "root") + assert.NoError(t, err) + + // user2 can access postgres + _, err = authorizer.Authorize("user2", "postgres") + assert.NoError(t, err) + + // user1 cannot access postgres + _, err = authorizer.Authorize("user1", "postgres") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) + + // user2 cannot access root + _, err = authorizer.Authorize("user2", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotMappedToOSUser) + + // Neither can access unmapped OS users + _, err = authorizer.Authorize("user1", "admin") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoMachineUserMapping) + + _, err = authorizer.Authorize("user2", "admin") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoMachineUserMapping) +} + +func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) { + // This test covers the scenario where wildcard exists with limited indexes. + // Only users whose indexes are in the wildcard list can access any OS user via wildcard. + // Other users can only access OS users they are explicitly mapped to. + authorizer := NewAuthorizer() + + // Create two authorized user hashes (simulating the base64-encoded hashes in the config) + wasmHash, err := sshauth.HashUserID("wasm") + require.NoError(t, err) + user2Hash, err := sshauth.HashUserID("user2") + require.NoError(t, err) + + // Configure with wildcard having only index 0, and specific mappings for other OS users + config := &Config{ + UserIDClaim: "sub", + AuthorizedUsers: []sshauth.UserIDHash{wasmHash, user2Hash}, + MachineUsers: map[string][]uint32{ + "*": {0}, // wildcard with only index 0 - only wasm has wildcard access + "alice": {1}, // specific mapping for user2 + "bob": {1}, // specific mapping for user2 + }, + } + authorizer.Update(config) + + // wasm (index 0) should access any OS user via wildcard + _, err = authorizer.Authorize("wasm", "root") + assert.NoError(t, err, "wasm should access root via wildcard") + + _, err = authorizer.Authorize("wasm", "alice") + assert.NoError(t, err, "wasm should access alice via wildcard") + + _, err = authorizer.Authorize("wasm", "bob") + assert.NoError(t, err, "wasm should access bob via wildcard") + + _, err = authorizer.Authorize("wasm", "postgres") + assert.NoError(t, err, "wasm should access postgres via wildcard") + + // user2 (index 1) should only access alice and bob (explicitly mapped), NOT root or postgres + _, err = authorizer.Authorize("user2", "alice") + assert.NoError(t, err, "user2 should access alice via explicit mapping") + + _, err = authorizer.Authorize("user2", "bob") + assert.NoError(t, err, "user2 should access bob via explicit mapping") + + _, err = authorizer.Authorize("user2", "root") + assert.Error(t, err, "user2 should NOT access root (not in wildcard indexes)") + assert.ErrorIs(t, err, ErrNoMachineUserMapping) + + _, err = authorizer.Authorize("user2", "postgres") + assert.Error(t, err, "user2 should NOT access postgres (not explicitly mapped)") + assert.ErrorIs(t, err, ErrNoMachineUserMapping) + + // Unauthorized user should still be denied + _, err = authorizer.Authorize("user3", "root") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied") +} diff --git a/client/ssh/client/client.go b/client/ssh/client/client.go index aab222093..342da7303 100644 --- a/client/ssh/client/client.go +++ b/client/ssh/client/client.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io" "net" "os" "path/filepath" @@ -551,14 +550,15 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) { defer func() { if err := localConn.Close(); err != nil { - log.Debugf("local connection close error: %v", err) + log.Debugf("local port forwarding: close local connection: %v", err) } }() channel, err := c.client.Dial("tcp", remoteAddr) if err != nil { - if strings.Contains(err.Error(), "administratively prohibited") { - _, _ = fmt.Fprintf(os.Stderr, "channel open failed: administratively prohibited: port forwarding is disabled\n") + var openErr *ssh.OpenChannelError + if errors.As(err, &openErr) && openErr.Reason == ssh.Prohibited { + _, _ = fmt.Fprintf(os.Stderr, "channel open failed: port forwarding is disabled\n") } else { log.Debugf("local port forwarding to %s failed: %v", remoteAddr, err) } @@ -566,19 +566,11 @@ func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) { } defer func() { if err := channel.Close(); err != nil { - log.Debugf("remote channel close error: %v", err) + log.Debugf("local port forwarding: close remote channel: %v", err) } }() - go func() { - if _, err := io.Copy(channel, localConn); err != nil { - log.Debugf("local forward copy error (local->remote): %v", err) - } - }() - - if _, err := io.Copy(localConn, channel); err != nil { - log.Debugf("local forward copy error (remote->local): %v", err) - } + nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel) } // RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr @@ -633,7 +625,7 @@ func (c *Client) sendTCPIPForwardRequest(req tcpipForwardMsg) error { return fmt.Errorf("send tcpip-forward request: %w", err) } if !ok { - return fmt.Errorf("remote port forwarding denied by server (check if --allow-ssh-remote-port-forwarding is enabled)") + return fmt.Errorf("remote port forwarding denied by server") } return nil } @@ -676,7 +668,7 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st } defer func() { if err := channel.Close(); err != nil { - log.Debugf("remote channel close error: %v", err) + log.Debugf("remote port forwarding: close remote channel: %v", err) } }() @@ -688,19 +680,11 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st } defer func() { if err := localConn.Close(); err != nil { - log.Debugf("local connection close error: %v", err) + log.Debugf("remote port forwarding: close local connection: %v", err) } }() - go func() { - if _, err := io.Copy(localConn, channel); err != nil { - log.Debugf("remote forward copy error (remote->local): %v", err) - } - }() - - if _, err := io.Copy(channel, localConn); err != nil { - log.Debugf("remote forward copy error (local->remote): %v", err) - } + nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel) } // tcpipForwardMsg represents the structure for tcpip-forward requests diff --git a/client/ssh/common.go b/client/ssh/common.go index 6574437b5..f6aec5f9c 100644 --- a/client/ssh/common.go +++ b/client/ssh/common.go @@ -193,3 +193,64 @@ func buildAddressList(hostname string, remote net.Addr) []string { } return addresses } + +// BidirectionalCopy copies data bidirectionally between two io.ReadWriter connections. +// It waits for both directions to complete before returning. +// The caller is responsible for closing the connections. +func BidirectionalCopy(logger *log.Entry, rw1, rw2 io.ReadWriter) { + done := make(chan struct{}, 2) + + go func() { + if _, err := io.Copy(rw2, rw1); err != nil && !isExpectedCopyError(err) { + logger.Debugf("copy error (1->2): %v", err) + } + done <- struct{}{} + }() + + go func() { + if _, err := io.Copy(rw1, rw2); err != nil && !isExpectedCopyError(err) { + logger.Debugf("copy error (2->1): %v", err) + } + done <- struct{}{} + }() + + <-done + <-done +} + +func isExpectedCopyError(err error) bool { + return errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) +} + +// BidirectionalCopyWithContext copies data bidirectionally between two io.ReadWriteCloser connections. +// It waits for both directions to complete or for context cancellation before returning. +// Both connections are closed when the function returns. +func BidirectionalCopyWithContext(logger *log.Entry, ctx context.Context, conn1, conn2 io.ReadWriteCloser) { + done := make(chan struct{}, 2) + + go func() { + if _, err := io.Copy(conn2, conn1); err != nil && !isExpectedCopyError(err) { + logger.Debugf("copy error (1->2): %v", err) + } + done <- struct{}{} + }() + + go func() { + if _, err := io.Copy(conn1, conn2); err != nil && !isExpectedCopyError(err) { + logger.Debugf("copy error (2->1): %v", err) + } + done <- struct{}{} + }() + + select { + case <-ctx.Done(): + case <-done: + select { + case <-ctx.Done(): + case <-done: + } + } + + _ = conn1.Close() + _ = conn2.Close() +} diff --git a/client/ssh/proxy/proxy.go b/client/ssh/proxy/proxy.go index 4e807e33c..cb1c36e13 100644 --- a/client/ssh/proxy/proxy.go +++ b/client/ssh/proxy/proxy.go @@ -2,6 +2,7 @@ package proxy import ( "context" + "encoding/binary" "errors" "fmt" "io" @@ -42,6 +43,14 @@ type SSHProxy struct { conn *grpc.ClientConn daemonClient proto.DaemonServiceClient browserOpener func(string) error + + mu sync.RWMutex + backendClient *cryptossh.Client + // jwtToken is set once in runProxySSHServer before any handlers are called, + // so concurrent access is safe without additional synchronization. + jwtToken string + + forwardedChannelsOnce sync.Once } func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browserOpener func(string) error) (*SSHProxy, error) { @@ -63,6 +72,17 @@ func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browse } func (p *SSHProxy) Close() error { + p.mu.Lock() + backendClient := p.backendClient + p.backendClient = nil + p.mu.Unlock() + + if backendClient != nil { + if err := backendClient.Close(); err != nil { + log.Debugf("close backend client: %v", err) + } + } + if p.conn != nil { return p.conn.Close() } @@ -77,16 +97,16 @@ func (p *SSHProxy) Connect(ctx context.Context) error { return fmt.Errorf(jwtAuthErrorMsg, err) } - return p.runProxySSHServer(ctx, jwtToken) + log.Debugf("JWT authentication successful, starting proxy to %s:%d", p.targetHost, p.targetPort) + return p.runProxySSHServer(jwtToken) } -func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error { +func (p *SSHProxy) runProxySSHServer(jwtToken string) error { + p.jwtToken = jwtToken serverVersion := fmt.Sprintf("%s-%s", detection.ProxyIdentifier, version.NetbirdVersion()) sshServer := &ssh.Server{ - Handler: func(s ssh.Session) { - p.handleSSHSession(ctx, s, jwtToken) - }, + Handler: p.handleSSHSession, ChannelHandlers: map[string]ssh.ChannelHandler{ "session": ssh.DefaultSessionHandler, "direct-tcpip": p.directTCPIPHandler, @@ -119,15 +139,20 @@ func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error return nil } -func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jwtToken string) { - targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort)) +func (p *SSHProxy) handleSSHSession(session ssh.Session) { + ptyReq, winCh, isPty := session.Pty() + hasCommand := len(session.Command()) > 0 - sshClient, err := p.dialBackend(ctx, targetAddr, session.User(), jwtToken) + sshClient, err := p.getOrCreateBackendClient(session.Context(), session.User()) if err != nil { _, _ = fmt.Fprintf(p.stderr, "SSH connection to NetBird server failed: %v\n", err) return } - defer func() { _ = sshClient.Close() }() + + if !isPty && !hasCommand { + p.handleNonInteractiveSession(session, sshClient) + return + } serverSession, err := sshClient.NewSession() if err != nil { @@ -140,7 +165,6 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw serverSession.Stdout = session serverSession.Stderr = session.Stderr() - ptyReq, winCh, isPty := session.Pty() if isPty { if err := serverSession.RequestPty(ptyReq.Term, ptyReq.Window.Width, ptyReq.Window.Height, nil); err != nil { log.Debugf("PTY request to backend: %v", err) @@ -155,7 +179,7 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw }() } - if len(session.Command()) > 0 { + if hasCommand { if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil { log.Debugf("run command: %v", err) p.handleProxyExitCode(session, err) @@ -176,12 +200,29 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) { var exitErr *cryptossh.ExitError if errors.As(err, &exitErr) { - if exitErr := session.Exit(exitErr.ExitStatus()); exitErr != nil { - log.Debugf("set exit status: %v", exitErr) + if err := session.Exit(exitErr.ExitStatus()); err != nil { + log.Debugf("set exit status: %v", err) } } } +func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *cryptossh.Client) { + // Create a backend session to mirror the client's session request. + // This keeps the connection alive on the server side while port forwarding channels operate. + serverSession, err := sshClient.NewSession() + if err != nil { + _, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err) + return + } + defer func() { _ = serverSession.Close() }() + + <-session.Context().Done() + + if err := session.Exit(0); err != nil { + log.Debugf("session exit: %v", err) + } +} + func generateHostKey() (ssh.Signer, error) { keyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519) if err != nil { @@ -250,8 +291,52 @@ func (c *stdioConn) SetWriteDeadline(_ time.Time) error { return nil } -func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, _ ssh.Context) { - _ = newChan.Reject(cryptossh.Prohibited, "port forwarding not supported in proxy") +// directTCPIPHandler handles local port forwarding (direct-tcpip channel). +func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, sshCtx ssh.Context) { + var payload struct { + DestAddr string + DestPort uint32 + OriginAddr string + OriginPort uint32 + } + if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil { + _, _ = fmt.Fprintf(p.stderr, "parse direct-tcpip payload: %v\n", err) + _ = newChan.Reject(cryptossh.ConnectionFailed, "invalid payload") + return + } + + dest := fmt.Sprintf("%s:%d", payload.DestAddr, payload.DestPort) + log.Debugf("local port forwarding: %s", dest) + + backendClient, err := p.getOrCreateBackendClient(sshCtx, sshCtx.User()) + if err != nil { + _, _ = fmt.Fprintf(p.stderr, "backend connection for port forwarding: %v\n", err) + _ = newChan.Reject(cryptossh.ConnectionFailed, "backend connection failed") + return + } + + backendChan, backendReqs, err := backendClient.OpenChannel("direct-tcpip", newChan.ExtraData()) + if err != nil { + _, _ = fmt.Fprintf(p.stderr, "open backend channel for %s: %v\n", dest, err) + var openErr *cryptossh.OpenChannelError + if errors.As(err, &openErr) { + _ = newChan.Reject(openErr.Reason, openErr.Message) + } else { + _ = newChan.Reject(cryptossh.ConnectionFailed, err.Error()) + } + return + } + go cryptossh.DiscardRequests(backendReqs) + + clientChan, clientReqs, err := newChan.Accept() + if err != nil { + log.Debugf("local port forwarding: accept channel: %v", err) + _ = backendChan.Close() + return + } + go cryptossh.DiscardRequests(clientReqs) + + nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan) } func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) { @@ -354,12 +439,143 @@ func (p *SSHProxy) runSFTPBridge(ctx context.Context, s ssh.Session, stdin io.Wr } } -func (p *SSHProxy) tcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) { - return false, []byte("port forwarding not supported in proxy") +// tcpipForwardHandler handles remote port forwarding (tcpip-forward request). +func (p *SSHProxy) tcpipForwardHandler(sshCtx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) { + var reqPayload struct { + Host string + Port uint32 + } + if err := cryptossh.Unmarshal(req.Payload, &reqPayload); err != nil { + _, _ = fmt.Fprintf(p.stderr, "parse tcpip-forward payload: %v\n", err) + return false, nil + } + + log.Debugf("tcpip-forward request for %s:%d", reqPayload.Host, reqPayload.Port) + + backendClient, err := p.getOrCreateBackendClient(sshCtx, sshCtx.User()) + if err != nil { + _, _ = fmt.Fprintf(p.stderr, "backend connection for remote port forwarding: %v\n", err) + return false, nil + } + + ok, payload, err := backendClient.SendRequest(req.Type, req.WantReply, req.Payload) + if err != nil { + _, _ = fmt.Fprintf(p.stderr, "forward tcpip-forward request for %s:%d: %v\n", reqPayload.Host, reqPayload.Port, err) + return false, nil + } + + if ok { + actualPort := reqPayload.Port + if reqPayload.Port == 0 && len(payload) >= 4 { + actualPort = binary.BigEndian.Uint32(payload) + } + log.Debugf("remote port forwarding established for %s:%d", reqPayload.Host, actualPort) + p.forwardedChannelsOnce.Do(func() { + go p.handleForwardedChannels(sshCtx, backendClient) + }) + } + + return ok, payload } -func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) { - return true, nil +// cancelTcpipForwardHandler handles cancel-tcpip-forward request. +func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) { + var reqPayload struct { + Host string + Port uint32 + } + if err := cryptossh.Unmarshal(req.Payload, &reqPayload); err != nil { + _, _ = fmt.Fprintf(p.stderr, "parse cancel-tcpip-forward payload: %v\n", err) + return false, nil + } + + log.Debugf("cancel-tcpip-forward request for %s:%d", reqPayload.Host, reqPayload.Port) + + backendClient := p.getBackendClient() + if backendClient == nil { + return false, nil + } + + ok, payload, err := backendClient.SendRequest(req.Type, req.WantReply, req.Payload) + if err != nil { + _, _ = fmt.Fprintf(p.stderr, "cancel-tcpip-forward for %s:%d: %v\n", reqPayload.Host, reqPayload.Port, err) + return false, nil + } + + return ok, payload +} + +// getOrCreateBackendClient returns the existing backend client or creates a new one. +func (p *SSHProxy) getOrCreateBackendClient(ctx context.Context, user string) (*cryptossh.Client, error) { + p.mu.Lock() + defer p.mu.Unlock() + + if p.backendClient != nil { + return p.backendClient, nil + } + + targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort)) + log.Debugf("connecting to backend %s", targetAddr) + + client, err := p.dialBackend(ctx, targetAddr, user, p.jwtToken) + if err != nil { + return nil, err + } + + log.Debugf("backend connection established to %s", targetAddr) + p.backendClient = client + return client, nil +} + +// getBackendClient returns the existing backend client or nil. +func (p *SSHProxy) getBackendClient() *cryptossh.Client { + p.mu.RLock() + defer p.mu.RUnlock() + return p.backendClient +} + +// handleForwardedChannels handles forwarded-tcpip channels from the backend for remote port forwarding. +// When the backend receives incoming connections on the forwarded port, it sends them as +// "forwarded-tcpip" channels which we need to proxy to the client. +func (p *SSHProxy) handleForwardedChannels(sshCtx ssh.Context, backendClient *cryptossh.Client) { + sshConn, ok := sshCtx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn) + if !ok || sshConn == nil { + log.Debugf("no SSH connection in context for forwarded channels") + return + } + + channelChan := backendClient.HandleChannelOpen("forwarded-tcpip") + for { + select { + case <-sshCtx.Done(): + return + case newChannel, ok := <-channelChan: + if !ok { + return + } + go p.handleForwardedChannel(sshCtx, sshConn, newChannel) + } + } +} + +// handleForwardedChannel handles a single forwarded-tcpip channel from the backend. +func (p *SSHProxy) handleForwardedChannel(sshCtx ssh.Context, sshConn *cryptossh.ServerConn, newChannel cryptossh.NewChannel) { + backendChan, backendReqs, err := newChannel.Accept() + if err != nil { + log.Debugf("remote port forwarding: accept from backend: %v", err) + return + } + go cryptossh.DiscardRequests(backendReqs) + + clientChan, clientReqs, err := sshConn.OpenChannel("forwarded-tcpip", newChannel.ExtraData()) + if err != nil { + log.Debugf("remote port forwarding: open to client: %v", err) + _ = backendChan.Close() + return + } + go cryptossh.DiscardRequests(clientReqs) + + nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan) } func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) { diff --git a/client/ssh/proxy/proxy_test.go b/client/ssh/proxy/proxy_test.go index 582f9c07b..81d588801 100644 --- a/client/ssh/proxy/proxy_test.go +++ b/client/ssh/proxy/proxy_test.go @@ -27,9 +27,11 @@ import ( "github.com/netbirdio/netbird/client/proto" nbssh "github.com/netbirdio/netbird/client/ssh" + sshauth "github.com/netbirdio/netbird/client/ssh/auth" "github.com/netbirdio/netbird/client/ssh/server" "github.com/netbirdio/netbird/client/ssh/testutil" nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" + sshuserhash "github.com/netbirdio/netbird/shared/sshauth" ) func TestMain(m *testing.M) { @@ -137,6 +139,21 @@ func TestSSHProxy_Connect(t *testing.T) { sshServer := server.New(serverConfig) sshServer.SetAllowRootLogin(true) + // Configure SSH authorization for the test user + testUsername := testutil.GetTestUsername(t) + testJWTUser := "test-username" + testUserHash, err := sshuserhash.HashUserID(testJWTUser) + require.NoError(t, err) + + authConfig := &sshauth.Config{ + UserIDClaim: sshauth.DefaultUserIDClaim, + AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash}, + MachineUsers: map[string][]uint32{ + testUsername: {0}, // Index 0 in AuthorizedUsers + }, + } + sshServer.UpdateSSHAuth(authConfig) + sshServerAddr := server.StartTestServer(t, sshServer) defer func() { _ = sshServer.Stop() }() @@ -150,10 +167,10 @@ func TestSSHProxy_Connect(t *testing.T) { mockDaemon.setHostKey(host, hostPubKey) - validToken := generateValidJWT(t, privateKey, issuer, audience) + validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser) mockDaemon.setJWTToken(validToken) - proxyInstance, err := New(mockDaemon.addr, host, port, nil, nil) + proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil) require.NoError(t, err) clientConn, proxyConn := net.Pipe() @@ -347,12 +364,12 @@ func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) { return privateKey, jwksJSON } -func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string { +func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string { t.Helper() claims := jwt.MapClaims{ "iss": issuer, "aud": audience, - "sub": "test-user", + "sub": user, "exp": time.Now().Add(time.Hour).Unix(), "iat": time.Now().Unix(), } diff --git a/client/ssh/server/jwt_test.go b/client/ssh/server/jwt_test.go index 1f3bac76d..d36d7cbbf 100644 --- a/client/ssh/server/jwt_test.go +++ b/client/ssh/server/jwt_test.go @@ -23,10 +23,12 @@ import ( "github.com/stretchr/testify/require" nbssh "github.com/netbirdio/netbird/client/ssh" + sshauth "github.com/netbirdio/netbird/client/ssh/auth" "github.com/netbirdio/netbird/client/ssh/client" "github.com/netbirdio/netbird/client/ssh/detection" "github.com/netbirdio/netbird/client/ssh/testutil" nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" + sshuserhash "github.com/netbirdio/netbird/shared/sshauth" ) func TestJWTEnforcement(t *testing.T) { @@ -577,6 +579,22 @@ func TestJWTAuthentication(t *testing.T) { tc.setupServer(server) } + // Always set up authorization for test-user to ensure tests fail at JWT validation stage + testUserHash, err := sshuserhash.HashUserID("test-user") + require.NoError(t, err) + + // Get current OS username for machine user mapping + currentUser := testutil.GetTestUsername(t) + + authConfig := &sshauth.Config{ + UserIDClaim: sshauth.DefaultUserIDClaim, + AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash}, + MachineUsers: map[string][]uint32{ + currentUser: {0}, // Allow test-user (index 0) to access current OS user + }, + } + server.UpdateSSHAuth(authConfig) + serverAddr := StartTestServer(t, server) defer require.NoError(t, server.Stop()) diff --git a/client/ssh/server/port_forwarding.go b/client/ssh/server/port_forwarding.go index 6138f9296..c60cf4f58 100644 --- a/client/ssh/server/port_forwarding.go +++ b/client/ssh/server/port_forwarding.go @@ -1,25 +1,32 @@ +// Package server implements port forwarding for the SSH server. +// +// Security note: Port forwarding runs in the main server process without privilege separation. +// The attack surface is primarily io.Copy through well-tested standard library code, making it +// lower risk than shell execution which uses privilege-separated child processes. We enforce +// user-level port restrictions: non-privileged users cannot bind to ports < 1024. package server import ( "encoding/binary" "fmt" - "io" "net" + "runtime" "strconv" "github.com/gliderlabs/ssh" log "github.com/sirupsen/logrus" cryptossh "golang.org/x/crypto/ssh" + + nbssh "github.com/netbirdio/netbird/client/ssh" ) -// SessionKey uniquely identifies an SSH session -type SessionKey string +const privilegedPortThreshold = 1024 -// ConnectionKey uniquely identifies a port forwarding connection within a session -type ConnectionKey string +// sessionKey uniquely identifies an SSH session +type sessionKey string -// ForwardKey uniquely identifies a port forwarding listener -type ForwardKey string +// forwardKey uniquely identifies a port forwarding listener +type forwardKey string // tcpipForwardMsg represents the structure for tcpip-forward SSH requests type tcpipForwardMsg struct { @@ -47,34 +54,32 @@ func (s *Server) configurePortForwarding(server *ssh.Server) { allowRemote := s.allowRemotePortForwarding server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool { + logger := s.getRequestLogger(ctx) if !allowLocal { - log.Warnf("local port forwarding denied for %s from %s: disabled by configuration", - net.JoinHostPort(dstHost, fmt.Sprintf("%d", dstPort)), ctx.RemoteAddr()) + logger.Warnf("local port forwarding denied for %s:%d: disabled", dstHost, dstPort) return false } if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil { - log.Warnf("local port forwarding denied for %s:%d from %s: %v", dstHost, dstPort, ctx.RemoteAddr(), err) + logger.Warnf("local port forwarding denied for %s:%d: %v", dstHost, dstPort, err) return false } - log.Debugf("local port forwarding allowed: %s:%d", dstHost, dstPort) return true } server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool { + logger := s.getRequestLogger(ctx) if !allowRemote { - log.Warnf("remote port forwarding denied for %s from %s: disabled by configuration", - net.JoinHostPort(bindHost, fmt.Sprintf("%d", bindPort)), ctx.RemoteAddr()) + logger.Warnf("remote port forwarding denied for %s:%d: disabled", bindHost, bindPort) return false } if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil { - log.Warnf("remote port forwarding denied for %s:%d from %s: %v", bindHost, bindPort, ctx.RemoteAddr(), err) + logger.Warnf("remote port forwarding denied for %s:%d: %v", bindHost, bindPort, err) return false } - log.Debugf("remote port forwarding allowed: %s:%d", bindHost, bindPort) return true } @@ -82,23 +87,20 @@ func (s *Server) configurePortForwarding(server *ssh.Server) { } // checkPortForwardingPrivileges validates privilege requirements for port forwarding operations. -// Returns nil if allowed, error if denied. +// For remote port forwarding (binding), it enforces that non-privileged users cannot bind to +// ports below 1024, mirroring the restriction they would face if binding directly. +// +// Note: FeatureSupportsUserSwitch is true because we accept requests from any authenticated user, +// though we don't actually switch users - port forwarding runs in the server process. The resolved +// user is used for privileged port access checks. func (s *Server) checkPortForwardingPrivileges(ctx ssh.Context, forwardType string, port uint32) error { if ctx == nil { return fmt.Errorf("%s port forwarding denied: no context", forwardType) } - username := ctx.User() - remoteAddr := "unknown" - if ctx.RemoteAddr() != nil { - remoteAddr = ctx.RemoteAddr().String() - } - - logger := log.WithFields(log.Fields{"user": username, "remote": remoteAddr, "port": port}) - result := s.CheckPrivileges(PrivilegeCheckRequest{ - RequestedUsername: username, - FeatureSupportsUserSwitch: false, + RequestedUsername: ctx.User(), + FeatureSupportsUserSwitch: true, FeatureName: forwardType + " port forwarding", }) @@ -106,12 +108,42 @@ func (s *Server) checkPortForwardingPrivileges(ctx ssh.Context, forwardType stri return result.Error } - logger.Debugf("%s port forwarding allowed: user %s validated (port %d)", - forwardType, result.User.Username, port) + if err := s.checkPrivilegedPortAccess(forwardType, port, result); err != nil { + return err + } return nil } +// checkPrivilegedPortAccess enforces that non-privileged users cannot bind to privileged ports. +// This applies to remote port forwarding where the server binds a port on behalf of the user. +// On Windows, there is no privileged port restriction, so this check is skipped. +func (s *Server) checkPrivilegedPortAccess(forwardType string, port uint32, result PrivilegeCheckResult) error { + if runtime.GOOS == "windows" { + return nil + } + + isBindOperation := forwardType == "remote" || forwardType == "tcpip-forward" + if !isBindOperation { + return nil + } + + // Port 0 means "pick any available port", which will be >= 1024 + if port == 0 || port >= privilegedPortThreshold { + return nil + } + + if result.User != nil && isPrivilegedUsername(result.User.Username) { + return nil + } + + username := "unknown" + if result.User != nil { + username = result.User.Username + } + return fmt.Errorf("user %s cannot bind to privileged port %d (requires root)", username, port) +} + // tcpipForwardHandler handles tcpip-forward requests for remote port forwarding. func (s *Server) tcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) { logger := s.getRequestLogger(ctx) @@ -132,8 +164,6 @@ func (s *Server) tcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *crypto return false, nil } - logger.Debugf("tcpip-forward request: %s:%d", payload.Host, payload.Port) - sshConn, err := s.getSSHConnection(ctx) if err != nil { logger.Warnf("tcpip-forward request denied: %v", err) @@ -153,8 +183,10 @@ func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req * return false, nil } - key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port)) + key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port)) if s.removeRemoteForwardListener(key) { + forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, payload.Port) + s.removeConnectionPortForward(ctx.RemoteAddr(), forwardAddr) logger.Infof("remote port forwarding cancelled: %s:%d", payload.Host, payload.Port) return true, nil } @@ -165,14 +197,11 @@ func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req * // handleRemoteForwardListener handles incoming connections for remote port forwarding. func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, host string, port uint32) { - log.Debugf("starting remote forward listener handler for %s:%d", host, port) + logger := s.getRequestLogger(ctx) defer func() { - log.Debugf("cleaning up remote forward listener for %s:%d", host, port) if err := ln.Close(); err != nil { - log.Debugf("remote forward listener close error: %v", err) - } else { - log.Debugf("remote forward listener closed successfully for %s:%d", host, port) + logger.Debugf("remote forward listener close error for %s:%d: %v", host, port, err) } }() @@ -196,28 +225,43 @@ func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, h select { case result := <-acceptChan: if result.err != nil { - log.Debugf("remote forward accept error: %v", result.err) + logger.Debugf("remote forward accept error: %v", result.err) return } go s.handleRemoteForwardConnection(ctx, result.conn, host, port) case <-ctx.Done(): - log.Debugf("remote forward listener shutting down due to context cancellation for %s:%d", host, port) + logger.Debugf("remote forward listener shutting down for %s:%d", host, port) return } } } -// getRequestLogger creates a logger with user and remote address context +// getRequestLogger creates a logger with session/conn and jwt_user context func (s *Server) getRequestLogger(ctx ssh.Context) *log.Entry { - remoteAddr := "unknown" - username := "unknown" - if ctx != nil { - if ctx.RemoteAddr() != nil { - remoteAddr = ctx.RemoteAddr().String() + sessionKey := s.findSessionKeyByContext(ctx) + + s.mu.RLock() + defer s.mu.RUnlock() + + if state, exists := s.sessions[sessionKey]; exists { + logger := log.WithField("session", sessionKey) + if state.jwtUsername != "" { + logger = logger.WithField("jwt_user", state.jwtUsername) } - username = ctx.User() + return logger } - return log.WithFields(log.Fields{"user": username, "remote": remoteAddr}) + + if ctx.RemoteAddr() != nil { + if connState, exists := s.connections[connKey(ctx.RemoteAddr().String())]; exists { + return s.connLogger(connState) + } + } + + remoteAddr := "unknown" + if ctx.RemoteAddr() != nil { + remoteAddr = ctx.RemoteAddr().String() + } + return log.WithField("session", fmt.Sprintf("%s@%s", ctx.User(), remoteAddr)) } // isRemotePortForwardingAllowed checks if remote port forwarding is enabled @@ -227,6 +271,13 @@ func (s *Server) isRemotePortForwardingAllowed() bool { return s.allowRemotePortForwarding } +// isPortForwardingEnabled checks if any port forwarding (local or remote) is enabled +func (s *Server) isPortForwardingEnabled() bool { + s.mu.RLock() + defer s.mu.RUnlock() + return s.allowLocalPortForwarding || s.allowRemotePortForwarding +} + // parseTcpipForwardRequest parses the SSH request payload func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) { var payload tcpipForwardMsg @@ -267,10 +318,11 @@ func (s *Server) setupDirectForward(ctx ssh.Context, logger *log.Entry, sshConn logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host) } - key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port)) + key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port)) s.storeRemoteForwardListener(key, ln) - s.markConnectionActivePortForward(sshConn, ctx.User(), ctx.RemoteAddr().String()) + forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, actualPort) + s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr) go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort) response := make([]byte, 4) @@ -288,44 +340,34 @@ type acceptResult struct { // handleRemoteForwardConnection handles a single remote port forwarding connection func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, host string, port uint32) { - sessionKey := s.findSessionKeyByContext(ctx) - connID := fmt.Sprintf("pf-%s->%s:%d", conn.RemoteAddr(), host, port) - logger := log.WithFields(log.Fields{ - "session": sessionKey, - "conn": connID, - }) + logger := s.getRequestLogger(ctx) - defer func() { - if err := conn.Close(); err != nil { - logger.Debugf("connection close error: %v", err) - } - }() - - sshConn := ctx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn) - if sshConn == nil { + sshConn, ok := ctx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn) + if !ok || sshConn == nil { logger.Debugf("remote forward: no SSH connection in context") + _ = conn.Close() return } remoteAddr, ok := conn.RemoteAddr().(*net.TCPAddr) if !ok { logger.Warnf("remote forward: non-TCP connection type: %T", conn.RemoteAddr()) + _ = conn.Close() return } - channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr, logger) + channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr) if err != nil { - logger.Debugf("open forward channel: %v", err) + logger.Debugf("open forward channel for %s:%d: %v", host, port, err) + _ = conn.Close() return } - s.proxyForwardConnection(ctx, logger, conn, channel) + nbssh.BidirectionalCopyWithContext(logger, ctx, conn, channel) } // openForwardChannel creates an SSH forwarded-tcpip channel -func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string, port uint32, remoteAddr *net.TCPAddr, logger *log.Entry) (cryptossh.Channel, error) { - logger.Tracef("opening forwarded-tcpip channel for %s:%d", host, port) - +func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string, port uint32, remoteAddr *net.TCPAddr) (cryptossh.Channel, error) { payload := struct { ConnectedAddress string ConnectedPort uint32 @@ -346,41 +388,3 @@ func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string, go cryptossh.DiscardRequests(reqs) return channel, nil } - -// proxyForwardConnection handles bidirectional data transfer between connection and SSH channel -func (s *Server) proxyForwardConnection(ctx ssh.Context, logger *log.Entry, conn net.Conn, channel cryptossh.Channel) { - done := make(chan struct{}, 2) - - go func() { - if _, err := io.Copy(channel, conn); err != nil { - logger.Debugf("copy error (conn->channel): %v", err) - } - done <- struct{}{} - }() - - go func() { - if _, err := io.Copy(conn, channel); err != nil { - logger.Debugf("copy error (channel->conn): %v", err) - } - done <- struct{}{} - }() - - select { - case <-ctx.Done(): - logger.Debugf("session ended, closing connections") - case <-done: - // First copy finished, wait for second copy or context cancellation - select { - case <-ctx.Done(): - logger.Debugf("session ended, closing connections") - case <-done: - } - } - - if err := channel.Close(); err != nil { - logger.Debugf("channel close error: %v", err) - } - if err := conn.Close(); err != nil { - logger.Debugf("connection close error: %v", err) - } -} diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go index 37763ee0e..3a8568979 100644 --- a/client/ssh/server/server.go +++ b/client/ssh/server/server.go @@ -9,6 +9,7 @@ import ( "io" "net" "net/netip" + "slices" "strings" "sync" "time" @@ -21,6 +22,7 @@ import ( "golang.zx2c4.com/wireguard/tun/netstack" "github.com/netbirdio/netbird/client/iface/wgaddr" + sshauth "github.com/netbirdio/netbird/client/ssh/auth" "github.com/netbirdio/netbird/client/ssh/detection" "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/auth/jwt" @@ -39,6 +41,11 @@ const ( msgPrivilegedUserDisabled = "privileged user login is disabled" + cmdInteractiveShell = "" + cmdPortForwarding = "" + cmdSFTP = "" + cmdNonInteractive = "" + // DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server DefaultJWTMaxTokenAge = 5 * 60 ) @@ -89,10 +96,10 @@ func logSessionExitError(logger *log.Entry, err error) { } } -// safeLogCommand returns a safe representation of the command for logging +// safeLogCommand returns a safe representation of the command for logging. func safeLogCommand(cmd []string) string { if len(cmd) == 0 { - return "" + return cmdInteractiveShell } if len(cmd) == 1 { return cmd[0] @@ -100,26 +107,50 @@ func safeLogCommand(cmd []string) string { return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1) } -type sshConnectionState struct { - hasActivePortForward bool - username string - remoteAddr string +// connState tracks the state of an SSH connection for port forwarding and status display. +type connState struct { + username string + remoteAddr net.Addr + portForwards []string + jwtUsername string } +// authKey uniquely identifies an authentication attempt by username and remote address. +// Used to temporarily store JWT username between passwordHandler and sessionHandler. type authKey string +// connKey uniquely identifies an SSH connection by its remote address. +// Used to track authenticated connections for status display and port forwarding. +type connKey string + func newAuthKey(username string, remoteAddr net.Addr) authKey { return authKey(fmt.Sprintf("%s@%s", username, remoteAddr.String())) } +// sessionState tracks an active SSH session (shell, command, or subsystem like SFTP). +type sessionState struct { + session ssh.Session + sessionType string + jwtUsername string +} + type Server struct { - sshServer *ssh.Server - mu sync.RWMutex - hostKeyPEM []byte - sessions map[SessionKey]ssh.Session - sessionCancels map[ConnectionKey]context.CancelFunc - sessionJWTUsers map[SessionKey]string - pendingAuthJWT map[authKey]string + sshServer *ssh.Server + listener net.Listener + mu sync.RWMutex + hostKeyPEM []byte + + // sessions tracks active SSH sessions (shell, command, SFTP). + // These are created when a client opens a session channel and requests shell/exec/subsystem. + sessions map[sessionKey]*sessionState + + // pendingAuthJWT temporarily stores JWT username during the auth→session handoff. + // Populated in passwordHandler, consumed in sessionHandler/sftpSubsystemHandler. + pendingAuthJWT map[authKey]string + + // connections tracks all SSH connections by their remote address. + // Populated at authentication time, stores JWT username and port forwards for status display. + connections map[connKey]*connState allowLocalPortForwarding bool allowRemotePortForwarding bool @@ -131,13 +162,14 @@ type Server struct { wgAddress wgaddr.Address - remoteForwardListeners map[ForwardKey]net.Listener - sshConnections map[*cryptossh.ServerConn]*sshConnectionState + remoteForwardListeners map[forwardKey]net.Listener jwtValidator *jwt.Validator jwtExtractor *jwt.ClaimsExtractor jwtConfig *JWTConfig + authorizer *sshauth.Authorizer + suSupportsPty bool loginIsUtilLinux bool } @@ -164,6 +196,7 @@ type SessionInfo struct { RemoteAddress string Command string JWTUsername string + PortForwards []string } // New creates an SSH server instance with the provided host key and optional JWT configuration @@ -172,13 +205,13 @@ func New(config *Config) *Server { s := &Server{ mu: sync.RWMutex{}, hostKeyPEM: config.HostKeyPEM, - sessions: make(map[SessionKey]ssh.Session), - sessionJWTUsers: make(map[SessionKey]string), + sessions: make(map[sessionKey]*sessionState), pendingAuthJWT: make(map[authKey]string), - remoteForwardListeners: make(map[ForwardKey]net.Listener), - sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState), + remoteForwardListeners: make(map[forwardKey]net.Listener), + connections: make(map[connKey]*connState), jwtEnabled: config.JWT != nil, jwtConfig: config.JWT, + authorizer: sshauth.NewAuthorizer(), // Initialize with empty config } return s @@ -207,6 +240,7 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error { return fmt.Errorf("create SSH server: %w", err) } + s.listener = ln s.sshServer = sshServer log.Infof("SSH server started on %s", addrDesc) @@ -259,16 +293,11 @@ func (s *Server) Stop() error { } s.sshServer = nil + s.listener = nil maps.Clear(s.sessions) - maps.Clear(s.sessionJWTUsers) maps.Clear(s.pendingAuthJWT) - maps.Clear(s.sshConnections) - - for _, cancelFunc := range s.sessionCancels { - cancelFunc() - } - maps.Clear(s.sessionCancels) + maps.Clear(s.connections) for _, listener := range s.remoteForwardListeners { if err := listener.Close(); err != nil { @@ -280,32 +309,82 @@ func (s *Server) Stop() error { return nil } -// GetStatus returns the current status of the SSH server and active sessions +// Addr returns the address the SSH server is listening on, or nil if the server is not running +func (s *Server) Addr() net.Addr { + s.mu.RLock() + defer s.mu.RUnlock() + + if s.listener == nil { + return nil + } + + return s.listener.Addr() +} + +// GetStatus returns the current status of the SSH server and active sessions. func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) { s.mu.RLock() defer s.mu.RUnlock() enabled = s.sshServer != nil + reportedAddrs := make(map[string]bool) - for sessionKey, session := range s.sessions { - cmd := "" - if len(session.Command()) > 0 { - cmd = safeLogCommand(session.Command()) + for _, state := range s.sessions { + info := s.buildSessionInfo(state) + reportedAddrs[info.RemoteAddress] = true + sessions = append(sessions, info) + } + + // Add authenticated connections without sessions (e.g., -N/-T or port-forwarding only) + for key, connState := range s.connections { + remoteAddr := string(key) + if reportedAddrs[remoteAddr] { + continue + } + cmd := cmdNonInteractive + if len(connState.portForwards) > 0 { + cmd = cmdPortForwarding } - - jwtUsername := s.sessionJWTUsers[sessionKey] - sessions = append(sessions, SessionInfo{ - Username: session.User(), - RemoteAddress: session.RemoteAddr().String(), + Username: connState.username, + RemoteAddress: remoteAddr, Command: cmd, - JWTUsername: jwtUsername, + JWTUsername: connState.jwtUsername, + PortForwards: connState.portForwards, }) } return enabled, sessions } +func (s *Server) buildSessionInfo(state *sessionState) SessionInfo { + session := state.session + cmd := state.sessionType + if cmd == "" { + cmd = safeLogCommand(session.Command()) + } + + remoteAddr := session.RemoteAddr().String() + info := SessionInfo{ + Username: session.User(), + RemoteAddress: remoteAddr, + Command: cmd, + JWTUsername: state.jwtUsername, + } + + connState, exists := s.connections[connKey(remoteAddr)] + if !exists { + return info + } + + info.PortForwards = connState.portForwards + if len(connState.portForwards) > 0 && (cmd == cmdInteractiveShell || cmd == cmdNonInteractive) { + info.Command = cmdPortForwarding + } + + return info +} + // SetNetstackNet sets the netstack network for userspace networking func (s *Server) SetNetstackNet(net *netstack.Net) { s.mu.Lock() @@ -320,6 +399,19 @@ func (s *Server) SetNetworkValidation(addr wgaddr.Address) { s.wgAddress = addr } +// UpdateSSHAuth updates the SSH fine-grained access control configuration +// This should be called when network map updates include new SSH auth configuration +func (s *Server) UpdateSSHAuth(config *sshauth.Config) { + s.mu.Lock() + defer s.mu.Unlock() + + // Reset JWT validator/extractor to pick up new userIDClaim + s.jwtValidator = nil + s.jwtExtractor = nil + + s.authorizer.Update(config) +} + // ensureJWTValidator initializes the JWT validator and extractor if not already initialized func (s *Server) ensureJWTValidator() error { s.mu.RLock() @@ -328,6 +420,7 @@ func (s *Server) ensureJWTValidator() error { return nil } config := s.jwtConfig + authorizer := s.authorizer s.mu.RUnlock() if config == nil { @@ -343,9 +436,16 @@ func (s *Server) ensureJWTValidator() error { true, ) - extractor := jwt.NewClaimsExtractor( + // Use custom userIDClaim from authorizer if available + extractorOptions := []jwt.ClaimsExtractorOption{ jwt.WithAudience(config.Audience), - ) + } + if authorizer.GetUserIDClaim() != "" { + extractorOptions = append(extractorOptions, jwt.WithUserIDClaim(authorizer.GetUserIDClaim())) + log.Debugf("Using custom user ID claim: %s", authorizer.GetUserIDClaim()) + } + + extractor := jwt.NewClaimsExtractor(extractorOptions...) s.mu.Lock() defer s.mu.Unlock() @@ -493,59 +593,131 @@ func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]int } func (s *Server) passwordHandler(ctx ssh.Context, password string) bool { + osUsername := ctx.User() + remoteAddr := ctx.RemoteAddr() + logger := s.getRequestLogger(ctx) + if err := s.ensureJWTValidator(); err != nil { - log.Errorf("JWT validator initialization failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err) + logger.Errorf("JWT validator initialization failed: %v", err) return false } token, err := s.validateJWTToken(password) if err != nil { - log.Warnf("JWT authentication failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err) + logger.Warnf("JWT authentication failed: %v", err) return false } userAuth, err := s.extractAndValidateUser(token) if err != nil { - log.Warnf("User validation failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err) + logger.Warnf("user validation failed: %v", err) return false } - key := newAuthKey(ctx.User(), ctx.RemoteAddr()) + logger = logger.WithField("jwt_user", userAuth.UserId) + + s.mu.RLock() + authorizer := s.authorizer + s.mu.RUnlock() + + msg, err := authorizer.Authorize(userAuth.UserId, osUsername) + if err != nil { + logger.Warnf("SSH auth denied: %v", err) + return false + } + + logger.Infof("SSH auth %s", msg) + + key := newAuthKey(osUsername, remoteAddr) + remoteAddrStr := ctx.RemoteAddr().String() s.mu.Lock() s.pendingAuthJWT[key] = userAuth.UserId + s.connections[connKey(remoteAddrStr)] = &connState{ + username: ctx.User(), + remoteAddr: ctx.RemoteAddr(), + jwtUsername: userAuth.UserId, + } s.mu.Unlock() - log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr()) return true } -func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn, username, remoteAddr string) { +func (s *Server) addConnectionPortForward(username string, remoteAddr net.Addr, forwardAddr string) { s.mu.Lock() defer s.mu.Unlock() - if state, exists := s.sshConnections[sshConn]; exists { - state.hasActivePortForward = true - } else { - s.sshConnections[sshConn] = &sshConnectionState{ - hasActivePortForward: true, - username: username, - remoteAddr: remoteAddr, + key := connKey(remoteAddr.String()) + if state, exists := s.connections[key]; exists { + if !slices.Contains(state.portForwards, forwardAddr) { + state.portForwards = append(state.portForwards, forwardAddr) } + return + } + + // Connection not in connections (non-JWT auth path) + s.connections[key] = &connState{ + username: username, + remoteAddr: remoteAddr, + portForwards: []string{forwardAddr}, + jwtUsername: s.pendingAuthJWT[newAuthKey(username, remoteAddr)], } } -func (s *Server) connectionCloseHandler(conn net.Conn, err error) { - // We can't extract the SSH connection from net.Conn directly - // Connection cleanup will happen during session cleanup or via timeout - log.Debugf("SSH connection failed for %s: %v", conn.RemoteAddr(), err) +func (s *Server) removeConnectionPortForward(remoteAddr net.Addr, forwardAddr string) { + s.mu.Lock() + defer s.mu.Unlock() + + state, exists := s.connections[connKey(remoteAddr.String())] + if !exists { + return + } + + state.portForwards = slices.DeleteFunc(state.portForwards, func(addr string) bool { + return addr == forwardAddr + }) } -func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey { +// trackedConn wraps a net.Conn to detect when it closes +type trackedConn struct { + net.Conn + server *Server + remoteAddr string + onceClose sync.Once +} + +func (c *trackedConn) Close() error { + err := c.Conn.Close() + c.onceClose.Do(func() { + c.server.handleConnectionClose(c.remoteAddr) + }) + return err +} + +func (s *Server) handleConnectionClose(remoteAddr string) { + s.mu.Lock() + defer s.mu.Unlock() + + key := connKey(remoteAddr) + state, exists := s.connections[key] + if exists && len(state.portForwards) > 0 { + s.connLogger(state).Info("port forwarding connection closed") + } + delete(s.connections, key) +} + +func (s *Server) connLogger(state *connState) *log.Entry { + logger := log.WithField("session", fmt.Sprintf("%s@%s", state.username, state.remoteAddr)) + if state.jwtUsername != "" { + logger = logger.WithField("jwt_user", state.jwtUsername) + } + return logger +} + +func (s *Server) findSessionKeyByContext(ctx ssh.Context) sessionKey { if ctx == nil { return "unknown" } - // Try to match by SSH connection sshConn := ctx.Value(ssh.ContextKeyConn) if sshConn == nil { return "unknown" @@ -554,19 +726,14 @@ func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey { s.mu.RLock() defer s.mu.RUnlock() - // Look through sessions to find one with matching connection - for sessionKey, session := range s.sessions { - if session.Context().Value(ssh.ContextKeyConn) == sshConn { + for sessionKey, state := range s.sessions { + if state.session.Context().Value(ssh.ContextKeyConn) == sshConn { return sessionKey } } - // If no session found, this might be during early connection setup - // Return a temporary key that we'll fix up later if ctx.User() != "" && ctx.RemoteAddr() != nil { - tempKey := SessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String())) - log.Debugf("Using temporary session key for early port forward tracking: %s (will be updated when session established)", tempKey) - return tempKey + return sessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String())) } return "unknown" @@ -607,7 +774,11 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn { } log.Infof("SSH connection from NetBird peer %s allowed", tcpAddr) - return conn + return &trackedConn{ + Conn: conn, + server: s, + remoteAddr: conn.RemoteAddr().String(), + } } func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) { @@ -635,9 +806,8 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) { "tcpip-forward": s.tcpipForwardHandler, "cancel-tcpip-forward": s.cancelTcpipForwardHandler, }, - ConnCallback: s.connectionValidator, - ConnectionFailedCallback: s.connectionCloseHandler, - Version: serverVersion, + ConnCallback: s.connectionValidator, + Version: serverVersion, } if s.jwtEnabled { @@ -653,13 +823,13 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) { return server, nil } -func (s *Server) storeRemoteForwardListener(key ForwardKey, ln net.Listener) { +func (s *Server) storeRemoteForwardListener(key forwardKey, ln net.Listener) { s.mu.Lock() defer s.mu.Unlock() s.remoteForwardListeners[key] = ln } -func (s *Server) removeRemoteForwardListener(key ForwardKey) bool { +func (s *Server) removeRemoteForwardListener(key forwardKey) bool { s.mu.Lock() defer s.mu.Unlock() @@ -677,6 +847,8 @@ func (s *Server) removeRemoteForwardListener(key ForwardKey) bool { } func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) { + logger := s.getRequestLogger(ctx) + var payload struct { Host string Port uint32 @@ -686,7 +858,7 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil { if err := newChan.Reject(cryptossh.ConnectionFailed, "parse payload"); err != nil { - log.Debugf("channel reject error: %v", err) + logger.Debugf("channel reject error: %v", err) } return } @@ -696,19 +868,20 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, s.mu.RUnlock() if !allowLocal { - log.Warnf("local port forwarding denied for %s:%d: disabled by configuration", payload.Host, payload.Port) + logger.Warnf("local port forwarding denied for %s:%d: disabled", payload.Host, payload.Port) _ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled") return } - // Check privilege requirements for the destination port if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil { - log.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err) + logger.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err) _ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges") return } - log.Infof("local port forwarding: %s:%d", payload.Host, payload.Port) + forwardAddr := fmt.Sprintf("-L %s:%d", payload.Host, payload.Port) + s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr) + logger.Infof("local port forwarding: %s:%d", payload.Host, payload.Port) ssh.DirectTCPIPHandler(srv, conn, newChan, ctx) } diff --git a/client/ssh/server/server_config_test.go b/client/ssh/server/server_config_test.go index 24e455025..d85d85a51 100644 --- a/client/ssh/server/server_config_test.go +++ b/client/ssh/server/server_config_test.go @@ -224,6 +224,96 @@ func TestServer_PortForwardingRestriction(t *testing.T) { } } +func TestServer_PrivilegedPortAccess(t *testing.T) { + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + serverConfig := &Config{ + HostKeyPEM: hostKey, + } + server := New(serverConfig) + server.SetAllowRemotePortForwarding(true) + + tests := []struct { + name string + forwardType string + port uint32 + username string + expectError bool + errorMsg string + skipOnWindows bool + }{ + { + name: "non-root user remote forward privileged port", + forwardType: "remote", + port: 80, + username: "testuser", + expectError: true, + errorMsg: "cannot bind to privileged port", + skipOnWindows: true, + }, + { + name: "non-root user tcpip-forward privileged port", + forwardType: "tcpip-forward", + port: 443, + username: "testuser", + expectError: true, + errorMsg: "cannot bind to privileged port", + skipOnWindows: true, + }, + { + name: "non-root user remote forward unprivileged port", + forwardType: "remote", + port: 8080, + username: "testuser", + expectError: false, + }, + { + name: "non-root user remote forward port 0", + forwardType: "remote", + port: 0, + username: "testuser", + expectError: false, + }, + { + name: "root user remote forward privileged port", + forwardType: "remote", + port: 22, + username: "root", + expectError: false, + }, + { + name: "local forward privileged port allowed for non-root", + forwardType: "local", + port: 80, + username: "testuser", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skipOnWindows && runtime.GOOS == "windows" { + t.Skip("Windows does not have privileged port restrictions") + } + + result := PrivilegeCheckResult{ + Allowed: true, + User: &user.User{Username: tt.username}, + } + + err := server.checkPrivilegedPortAccess(tt.forwardType, tt.port, result) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + require.NoError(t, err) + } + }) + } +} + func TestServer_PortConflictHandling(t *testing.T) { // Test that multiple sessions requesting the same local port are handled naturally by the OS // Get current user for SSH connection @@ -392,3 +482,95 @@ func TestServer_IsPrivilegedUser(t *testing.T) { }) } } + +func TestServer_PortForwardingOnlySession(t *testing.T) { + // Test that sessions without PTY and command are allowed when port forwarding is enabled + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user") + + // Generate host key for server + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + tests := []struct { + name string + allowLocalForwarding bool + allowRemoteForwarding bool + expectAllowed bool + description string + }{ + { + name: "session_allowed_with_local_forwarding", + allowLocalForwarding: true, + allowRemoteForwarding: false, + expectAllowed: true, + description: "Port-forwarding-only session should be allowed when local forwarding is enabled", + }, + { + name: "session_allowed_with_remote_forwarding", + allowLocalForwarding: false, + allowRemoteForwarding: true, + expectAllowed: true, + description: "Port-forwarding-only session should be allowed when remote forwarding is enabled", + }, + { + name: "session_allowed_with_both", + allowLocalForwarding: true, + allowRemoteForwarding: true, + expectAllowed: true, + description: "Port-forwarding-only session should be allowed when both forwarding types enabled", + }, + { + name: "session_denied_without_forwarding", + allowLocalForwarding: false, + allowRemoteForwarding: false, + expectAllowed: false, + description: "Port-forwarding-only session should be denied when all forwarding is disabled", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + server.SetAllowLocalPortForwarding(tt.allowLocalForwarding) + server.SetAllowRemotePortForwarding(tt.allowRemoteForwarding) + + serverAddr := StartTestServer(t, server) + defer func() { + _ = server.Stop() + }() + + // Connect to the server without requesting PTY or command + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + client, err := sshclient.Dial(ctx, serverAddr, currentUser.Username, sshclient.DialOptions{ + InsecureSkipVerify: true, + }) + require.NoError(t, err) + defer func() { + _ = client.Close() + }() + + // Execute a command without PTY - this simulates ssh -T with no command + // The server should either allow it (port forwarding enabled) or reject it + output, err := client.ExecuteCommand(ctx, "") + if tt.expectAllowed { + // When allowed, the session stays open until cancelled + // ExecuteCommand with empty command should return without error + assert.NoError(t, err, "Session should be allowed when port forwarding is enabled") + assert.NotContains(t, output, "port forwarding is disabled", + "Output should not contain port forwarding disabled message") + } else if err != nil { + // When denied, we expect an error message about port forwarding being disabled + assert.Contains(t, err.Error(), "port forwarding is disabled", + "Should get port forwarding disabled message") + } + }) + } +} diff --git a/client/ssh/server/session_handlers.go b/client/ssh/server/session_handlers.go index 4e6d72098..3fd578064 100644 --- a/client/ssh/server/session_handlers.go +++ b/client/ssh/server/session_handlers.go @@ -6,37 +6,45 @@ import ( "errors" "fmt" "io" - "strings" "time" "github.com/gliderlabs/ssh" log "github.com/sirupsen/logrus" - cryptossh "golang.org/x/crypto/ssh" ) +// associateJWTUsername extracts pending JWT username for the session and associates it with the session state. +// Returns the JWT username (empty if none) for logging purposes. +func (s *Server) associateJWTUsername(sess ssh.Session, sessionKey sessionKey) string { + key := newAuthKey(sess.User(), sess.RemoteAddr()) + + s.mu.Lock() + defer s.mu.Unlock() + + jwtUsername := s.pendingAuthJWT[key] + if jwtUsername == "" { + return "" + } + + if state, exists := s.sessions[sessionKey]; exists { + state.jwtUsername = jwtUsername + } + delete(s.pendingAuthJWT, key) + return jwtUsername +} + // sessionHandler handles SSH sessions func (s *Server) sessionHandler(session ssh.Session) { - sessionKey := s.registerSession(session) - - key := newAuthKey(session.User(), session.RemoteAddr()) - s.mu.Lock() - jwtUsername := s.pendingAuthJWT[key] - if jwtUsername != "" { - s.sessionJWTUsers[sessionKey] = jwtUsername - delete(s.pendingAuthJWT, key) - } - s.mu.Unlock() + sessionKey := s.registerSession(session, "") + jwtUsername := s.associateJWTUsername(session, sessionKey) logger := log.WithField("session", sessionKey) if jwtUsername != "" { logger = logger.WithField("jwt_user", jwtUsername) - logger.Infof("SSH session started (JWT user: %s)", jwtUsername) - } else { - logger.Infof("SSH session started") } + logger.Info("SSH session started") sessionStart := time.Now() - defer s.unregisterSession(sessionKey, session) + defer s.unregisterSession(sessionKey) defer func() { duration := time.Since(sessionStart).Round(time.Millisecond) if err := session.Close(); err != nil && !errors.Is(err, io.EOF) { @@ -65,27 +73,52 @@ func (s *Server) sessionHandler(session ssh.Session) { // ssh - non-Pty command execution s.handleCommand(logger, session, privilegeResult, nil) default: - s.rejectInvalidSession(logger, session) + // ssh -T (or ssh -N) - no PTY, no command + s.handleNonInteractiveSession(logger, session) } } -func (s *Server) rejectInvalidSession(logger *log.Entry, session ssh.Session) { - if _, err := io.WriteString(session, "no command specified and Pty not requested\n"); err != nil { - logger.Debugf(errWriteSession, err) +// handleNonInteractiveSession handles sessions that have no PTY and no command. +// These are typically used for port forwarding (ssh -L/-R) or tunneling (ssh -N). +func (s *Server) handleNonInteractiveSession(logger *log.Entry, session ssh.Session) { + s.updateSessionType(session, cmdNonInteractive) + + if !s.isPortForwardingEnabled() { + if _, err := io.WriteString(session, "port forwarding is disabled on this server\n"); err != nil { + logger.Debugf(errWriteSession, err) + } + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + logger.Infof("rejected non-interactive session: port forwarding disabled") + return } - if err := session.Exit(1); err != nil { + + <-session.Context().Done() + + if err := session.Exit(0); err != nil { logSessionExitError(logger, err) } - logger.Infof("rejected non-Pty session without command from %s", session.RemoteAddr()) } -func (s *Server) registerSession(session ssh.Session) SessionKey { +func (s *Server) updateSessionType(session ssh.Session, sessionType string) { + s.mu.Lock() + defer s.mu.Unlock() + + for _, state := range s.sessions { + if state.session == session { + state.sessionType = sessionType + return + } + } +} + +func (s *Server) registerSession(session ssh.Session, sessionType string) sessionKey { sessionID := session.Context().Value(ssh.ContextKeySessionID) if sessionID == nil { sessionID = fmt.Sprintf("%p", session) } - // Create a short 4-byte identifier from the full session ID hasher := sha256.New() hasher.Write([]byte(fmt.Sprintf("%v", sessionID))) hash := hasher.Sum(nil) @@ -93,43 +126,23 @@ func (s *Server) registerSession(session ssh.Session) SessionKey { remoteAddr := session.RemoteAddr().String() username := session.User() - sessionKey := SessionKey(fmt.Sprintf("%s@%s-%s", username, remoteAddr, shortID)) + sessionKey := sessionKey(fmt.Sprintf("%s@%s-%s", username, remoteAddr, shortID)) s.mu.Lock() - s.sessions[sessionKey] = session + s.sessions[sessionKey] = &sessionState{ + session: session, + sessionType: sessionType, + } s.mu.Unlock() return sessionKey } -func (s *Server) unregisterSession(sessionKey SessionKey, session ssh.Session) { +func (s *Server) unregisterSession(sessionKey sessionKey) { s.mu.Lock() + defer s.mu.Unlock() + delete(s.sessions, sessionKey) - delete(s.sessionJWTUsers, sessionKey) - - // Cancel all port forwarding connections for this session - var connectionsToCancel []ConnectionKey - for key := range s.sessionCancels { - if strings.HasPrefix(string(key), string(sessionKey)+"-") { - connectionsToCancel = append(connectionsToCancel, key) - } - } - - for _, key := range connectionsToCancel { - if cancelFunc, exists := s.sessionCancels[key]; exists { - log.WithField("session", sessionKey).Debugf("cancelling port forwarding context: %s", key) - cancelFunc() - delete(s.sessionCancels, key) - } - } - - if sshConnValue := session.Context().Value(ssh.ContextKeyConn); sshConnValue != nil { - if sshConn, ok := sshConnValue.(*cryptossh.ServerConn); ok { - delete(s.sshConnections, sshConn) - } - } - - s.mu.Unlock() } func (s *Server) handlePrivError(logger *log.Entry, session ssh.Session, err error) { diff --git a/client/ssh/server/sftp.go b/client/ssh/server/sftp.go index c2b9f552b..199444abb 100644 --- a/client/ssh/server/sftp.go +++ b/client/ssh/server/sftp.go @@ -18,14 +18,26 @@ func (s *Server) SetAllowSFTP(allow bool) { // sftpSubsystemHandler handles SFTP subsystem requests func (s *Server) sftpSubsystemHandler(sess ssh.Session) { + sessionKey := s.registerSession(sess, cmdSFTP) + defer s.unregisterSession(sessionKey) + + jwtUsername := s.associateJWTUsername(sess, sessionKey) + + logger := log.WithField("session", sessionKey) + if jwtUsername != "" { + logger = logger.WithField("jwt_user", jwtUsername) + } + logger.Info("SFTP session started") + defer logger.Info("SFTP session closed") + s.mu.RLock() allowSFTP := s.allowSFTP s.mu.RUnlock() if !allowSFTP { - log.Debugf("SFTP subsystem request denied: SFTP disabled") + logger.Debug("SFTP subsystem request denied: SFTP disabled") if err := sess.Exit(1); err != nil { - log.Debugf("SFTP session exit failed: %v", err) + logger.Debugf("SFTP session exit: %v", err) } return } @@ -37,31 +49,27 @@ func (s *Server) sftpSubsystemHandler(sess ssh.Session) { }) if !result.Allowed { - log.Warnf("SFTP access denied for user %s from %s: %v", sess.User(), sess.RemoteAddr(), result.Error) + logger.Warnf("SFTP access denied: %v", result.Error) if err := sess.Exit(1); err != nil { - log.Debugf("exit SFTP session: %v", err) + logger.Debugf("exit SFTP session: %v", err) } return } - log.Debugf("SFTP subsystem request from user %s (effective user %s)", sess.User(), result.User.Username) - if !result.RequiresUserSwitching { if err := s.executeSftpDirect(sess); err != nil { - log.Errorf("SFTP direct execution: %v", err) + logger.Errorf("SFTP direct execution: %v", err) } return } if err := s.executeSftpWithPrivilegeDrop(sess, result.User); err != nil { - log.Errorf("SFTP privilege drop execution: %v", err) + logger.Errorf("SFTP privilege drop execution: %v", err) } } // executeSftpDirect executes SFTP directly without privilege dropping func (s *Server) executeSftpDirect(sess ssh.Session) error { - log.Debugf("starting SFTP session for user %s (no privilege dropping)", sess.User()) - sftpServer, err := sftp.NewServer(sess) if err != nil { return fmt.Errorf("SFTP server creation: %w", err) diff --git a/client/ssh/server/test.go b/client/ssh/server/test.go index 20930c721..f8abd1752 100644 --- a/client/ssh/server/test.go +++ b/client/ssh/server/test.go @@ -3,7 +3,6 @@ package server import ( "context" "fmt" - "net" "net/netip" "testing" "time" @@ -14,23 +13,21 @@ func StartTestServer(t *testing.T, server *Server) string { errChan := make(chan error, 1) go func() { - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - errChan <- err - return - } - actualAddr := ln.Addr().String() - if err := ln.Close(); err != nil { - errChan <- fmt.Errorf("close temp listener: %w", err) - return - } - - addrPort := netip.MustParseAddrPort(actualAddr) + // Use port 0 to let the OS assign a free port + addrPort := netip.MustParseAddrPort("127.0.0.1:0") if err := server.Start(context.Background(), addrPort); err != nil { errChan <- err return } - started <- actualAddr + + // Get the actual listening address from the server + actualAddr := server.Addr() + if actualAddr == nil { + errChan <- fmt.Errorf("server started but no listener address available") + return + } + + started <- actualAddr.String() }() select { diff --git a/client/status/status.go b/client/status/status.go index d975f0e29..4f31f3637 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -82,10 +82,11 @@ type NsServerGroupStateOutput struct { } type SSHSessionOutput struct { - Username string `json:"username" yaml:"username"` - RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"` - Command string `json:"command" yaml:"command"` - JWTUsername string `json:"jwtUsername,omitempty" yaml:"jwtUsername,omitempty"` + Username string `json:"username" yaml:"username"` + RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"` + Command string `json:"command" yaml:"command"` + JWTUsername string `json:"jwtUsername,omitempty" yaml:"jwtUsername,omitempty"` + PortForwards []string `json:"portForwards,omitempty" yaml:"portForwards,omitempty"` } type SSHServerStateOutput struct { @@ -220,6 +221,7 @@ func mapSSHServer(sshServerState *proto.SSHServerState) SSHServerStateOutput { RemoteAddress: session.GetRemoteAddress(), Command: session.GetCommand(), JWTUsername: session.GetJwtUsername(), + PortForwards: session.GetPortForwards(), }) } @@ -475,6 +477,9 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, ) } sshServerStatus += "\n " + sessionDisplay + for _, pf := range session.PortForwards { + sshServerStatus += "\n " + pf + } } } } diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 8f99608e7..78934ea95 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -34,6 +34,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + protobuf "google.golang.org/protobuf/proto" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal" @@ -43,7 +44,6 @@ import ( "github.com/netbirdio/netbird/client/ui/desktop" "github.com/netbirdio/netbird/client/ui/event" "github.com/netbirdio/netbird/client/ui/process" - "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/version" @@ -87,22 +87,24 @@ func main() { // Create the service client (this also builds the settings or networks UI if requested). client := newServiceClient(&newServiceClientArgs{ - addr: flags.daemonAddr, - logFile: logFile, - app: a, - showSettings: flags.showSettings, - showNetworks: flags.showNetworks, - showLoginURL: flags.showLoginURL, - showDebug: flags.showDebug, - showProfiles: flags.showProfiles, - showQuickActions: flags.showQuickActions, + addr: flags.daemonAddr, + logFile: logFile, + app: a, + showSettings: flags.showSettings, + showNetworks: flags.showNetworks, + showLoginURL: flags.showLoginURL, + showDebug: flags.showDebug, + showProfiles: flags.showProfiles, + showQuickActions: flags.showQuickActions, + showUpdate: flags.showUpdate, + showUpdateVersion: flags.showUpdateVersion, }) // Watch for theme/settings changes to update the icon. go watchSettingsChanges(a, client) // Run in window mode if any UI flag was set. - if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions { + if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions || flags.showUpdate { a.Run() return } @@ -128,15 +130,17 @@ func main() { } type cliFlags struct { - daemonAddr string - showSettings bool - showNetworks bool - showProfiles bool - showDebug bool - showLoginURL bool - showQuickActions bool - errorMsg string - saveLogsInFile bool + daemonAddr string + showSettings bool + showNetworks bool + showProfiles bool + showDebug bool + showLoginURL bool + showQuickActions bool + errorMsg string + saveLogsInFile bool + showUpdate bool + showUpdateVersion string } // parseFlags reads and returns all needed command-line flags. @@ -156,6 +160,8 @@ func parseFlags() *cliFlags { flag.StringVar(&flags.errorMsg, "error-msg", "", "displays an error message window") flag.BoolVar(&flags.saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir())) flag.BoolVar(&flags.showLoginURL, "login-url", false, "show login URL in a popup window") + flag.BoolVar(&flags.showUpdate, "update", false, "show update progress window") + flag.StringVar(&flags.showUpdateVersion, "update-version", "", "version to update to") flag.Parse() return &flags } @@ -306,6 +312,8 @@ type serviceClient struct { daemonVersion string updateIndicationLock sync.Mutex isUpdateIconActive bool + settingsEnabled bool + profilesEnabled bool showNetworks bool wNetworks fyne.Window wProfiles fyne.Window @@ -319,6 +327,8 @@ type serviceClient struct { mExitNodeDeselectAll *systray.MenuItem logFile string wLoginURL fyne.Window + wUpdateProgress fyne.Window + updateContextCancel context.CancelFunc connectCancel context.CancelFunc } @@ -329,15 +339,17 @@ type menuHandler struct { } type newServiceClientArgs struct { - addr string - logFile string - app fyne.App - showSettings bool - showNetworks bool - showDebug bool - showLoginURL bool - showProfiles bool - showQuickActions bool + addr string + logFile string + app fyne.App + showSettings bool + showNetworks bool + showDebug bool + showLoginURL bool + showProfiles bool + showQuickActions bool + showUpdate bool + showUpdateVersion string } // newServiceClient instance constructor @@ -355,7 +367,7 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient { showAdvancedSettings: args.showSettings, showNetworks: args.showNetworks, - update: version.NewUpdate("nb/client-ui"), + update: version.NewUpdateAndStart("nb/client-ui"), } s.eventHandler = newEventHandler(s) @@ -375,6 +387,8 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient { s.showProfilesUI() case args.showQuickActions: s.showQuickActionsUI() + case args.showUpdate: + s.showUpdateProgress(ctx, args.showUpdateVersion) } return s @@ -814,7 +828,7 @@ func (s *serviceClient) handleSSOLogin(ctx context.Context, loginResp *proto.Log return nil } -func (s *serviceClient) menuUpClick(ctx context.Context) error { +func (s *serviceClient) menuUpClick(ctx context.Context, wannaAutoUpdate bool) error { systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting) conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { @@ -836,7 +850,9 @@ func (s *serviceClient) menuUpClick(ctx context.Context) error { return nil } - if _, err := conn.Up(ctx, &proto.UpRequest{}); err != nil { + if _, err := s.conn.Up(s.ctx, &proto.UpRequest{ + AutoUpdate: protobuf.Bool(wannaAutoUpdate), + }); err != nil { return fmt.Errorf("start connection: %w", err) } @@ -893,7 +909,7 @@ func (s *serviceClient) updateStatus() error { var systrayIconState bool switch { - case status.Status == string(internal.StatusConnected): + case status.Status == string(internal.StatusConnected) && !s.connected: s.connected = true s.sendNotification = true if s.isUpdateIconActive { @@ -907,6 +923,7 @@ func (s *serviceClient) updateStatus() error { s.mUp.Disable() s.mDown.Enable() s.mNetworks.Enable() + s.mExitNode.Enable() go s.updateExitNodes() systrayIconState = true case status.Status == string(internal.StatusConnecting): @@ -1097,6 +1114,26 @@ func (s *serviceClient) onTrayReady() { s.updateExitNodes() } }) + s.eventManager.AddHandler(func(event *proto.SystemEvent) { + // todo use new Category + if windowAction, ok := event.Metadata["progress_window"]; ok { + targetVersion, ok := event.Metadata["version"] + if !ok { + targetVersion = "unknown" + } + log.Debugf("window action: %v", windowAction) + if windowAction == "show" { + if s.updateContextCancel != nil { + s.updateContextCancel() + s.updateContextCancel = nil + } + + subCtx, cancel := context.WithCancel(s.ctx) + go s.eventHandler.runSelfCommand(subCtx, "update", "--update-version", targetVersion) + s.updateContextCancel = cancel + } + } + }) go s.eventManager.Start(s.ctx) go s.eventHandler.listen(s.ctx) @@ -1240,19 +1277,22 @@ func (s *serviceClient) checkAndUpdateFeatures() { return } + s.updateIndicationLock.Lock() + defer s.updateIndicationLock.Unlock() + // Update settings menu based on current features - if features != nil && features.DisableUpdateSettings { - s.setSettingsEnabled(false) - } else { - s.setSettingsEnabled(true) + settingsEnabled := features == nil || !features.DisableUpdateSettings + if s.settingsEnabled != settingsEnabled { + s.settingsEnabled = settingsEnabled + s.setSettingsEnabled(settingsEnabled) } // Update profile menu based on current features if s.mProfile != nil { - if features != nil && features.DisableProfiles { - s.mProfile.setEnabled(false) - } else { - s.mProfile.setEnabled(true) + profilesEnabled := features == nil || !features.DisableProfiles + if s.profilesEnabled != profilesEnabled { + s.profilesEnabled = profilesEnabled + s.mProfile.setEnabled(profilesEnabled) } } } diff --git a/client/ui/event_handler.go b/client/ui/event_handler.go index e0b619411..9ffacd926 100644 --- a/client/ui/event_handler.go +++ b/client/ui/event_handler.go @@ -80,7 +80,7 @@ func (h *eventHandler) handleConnectClick() { go func() { defer connectCancel() - if err := h.client.menuUpClick(connectCtx); err != nil { + if err := h.client.menuUpClick(connectCtx, true); err != nil { st, ok := status.FromError(err) if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) { log.Debugf("connect operation cancelled by user") @@ -185,7 +185,7 @@ func (h *eventHandler) handleAdvancedSettingsClick() { go func() { defer h.client.mAdvancedSettings.Enable() defer h.client.getSrvConfig() - h.runSelfCommand(h.client.ctx, "settings", "true") + h.runSelfCommand(h.client.ctx, "settings") }() } @@ -193,7 +193,7 @@ func (h *eventHandler) handleCreateDebugBundleClick() { h.client.mCreateDebugBundle.Disable() go func() { defer h.client.mCreateDebugBundle.Enable() - h.runSelfCommand(h.client.ctx, "debug", "true") + h.runSelfCommand(h.client.ctx, "debug") }() } @@ -217,7 +217,7 @@ func (h *eventHandler) handleNetworksClick() { h.client.mNetworks.Disable() go func() { defer h.client.mNetworks.Enable() - h.runSelfCommand(h.client.ctx, "networks", "true") + h.runSelfCommand(h.client.ctx, "networks") }() } @@ -237,17 +237,21 @@ func (h *eventHandler) updateConfigWithErr() error { return nil } -func (h *eventHandler) runSelfCommand(ctx context.Context, command, arg string) { +func (h *eventHandler) runSelfCommand(ctx context.Context, command string, args ...string) { proc, err := os.Executable() if err != nil { log.Errorf("error getting executable path: %v", err) return } - cmd := exec.CommandContext(ctx, proc, - fmt.Sprintf("--%s=%s", command, arg), + // Build the full command arguments + cmdArgs := []string{ + fmt.Sprintf("--%s=true", command), fmt.Sprintf("--daemon-addr=%s", h.client.addr), - ) + } + cmdArgs = append(cmdArgs, args...) + + cmd := exec.CommandContext(ctx, proc, cmdArgs...) if out := h.client.attachOutput(cmd); out != nil { defer func() { @@ -257,17 +261,17 @@ func (h *eventHandler) runSelfCommand(ctx context.Context, command, arg string) }() } - log.Printf("running command: %s --%s=%s --daemon-addr=%s", proc, command, arg, h.client.addr) + log.Printf("running command: %s", cmd.String()) if err := cmd.Run(); err != nil { var exitErr *exec.ExitError if errors.As(err, &exitErr) { - log.Printf("command '%s %s' failed with exit code %d", command, arg, exitErr.ExitCode()) + log.Printf("command '%s' failed with exit code %d", cmd.String(), exitErr.ExitCode()) } return } - log.Printf("command '%s %s' completed successfully", command, arg) + log.Printf("command '%s' completed successfully", cmd.String()) } func (h *eventHandler) logout(ctx context.Context) error { diff --git a/client/ui/font_windows.go b/client/ui/font_windows.go index 93b23a21b..6346a9fb9 100644 --- a/client/ui/font_windows.go +++ b/client/ui/font_windows.go @@ -31,7 +31,6 @@ func (s *serviceClient) getWindowsFontFilePath() string { "chr-CHER-US": "Gadugi.ttf", "zh-HK": "Segoeui.ttf", "zh-TW": "Segoeui.ttf", - "ja-JP": "Yugothm.ttc", "km-KH": "Leelawui.ttf", "ko-KR": "Malgun.ttf", "th-TH": "Leelawui.ttf", diff --git a/client/ui/profile.go b/client/ui/profile.go index 74189c9a0..a38d8918a 100644 --- a/client/ui/profile.go +++ b/client/ui/profile.go @@ -397,7 +397,7 @@ type profileMenu struct { logoutSubItem *subItem profilesState []Profile downClickCallback func() error - upClickCallback func(context.Context) error + upClickCallback func(context.Context, bool) error getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) loadSettingsCallback func() app fyne.App @@ -411,7 +411,7 @@ type newProfileMenuArgs struct { profileMenuItem *systray.MenuItem emailMenuItem *systray.MenuItem downClickCallback func() error - upClickCallback func(context.Context) error + upClickCallback func(context.Context, bool) error getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) loadSettingsCallback func() app fyne.App @@ -579,7 +579,7 @@ func (p *profileMenu) refresh() { connectCtx, connectCancel := context.WithCancel(p.ctx) p.serviceClient.connectCancel = connectCancel - if err := p.upClickCallback(connectCtx); err != nil { + if err := p.upClickCallback(connectCtx, false); err != nil { log.Errorf("failed to handle up click after switching profile: %v", err) } diff --git a/client/ui/quickactions.go b/client/ui/quickactions.go index bf47ac434..76440d684 100644 --- a/client/ui/quickactions.go +++ b/client/ui/quickactions.go @@ -267,7 +267,7 @@ func (s *serviceClient) showQuickActionsUI() { connCmd := connectCommand{ connectClient: func() error { - return s.menuUpClick(s.ctx) + return s.menuUpClick(s.ctx, false) }, } diff --git a/client/ui/update.go b/client/ui/update.go new file mode 100644 index 000000000..25c317bdf --- /dev/null +++ b/client/ui/update.go @@ -0,0 +1,140 @@ +//go:build !(linux && 386) + +package main + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "fyne.io/fyne/v2/container" + "fyne.io/fyne/v2/widget" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/proto" +) + +func (s *serviceClient) showUpdateProgress(ctx context.Context, version string) { + log.Infof("show installer progress window: %s", version) + s.wUpdateProgress = s.app.NewWindow("Automatically updating client") + + statusLabel := widget.NewLabel("Updating...") + infoLabel := widget.NewLabel(fmt.Sprintf("Your client version is older than the auto-update version set in Management.\nUpdating client to: %s.", version)) + content := container.NewVBox(infoLabel, statusLabel) + s.wUpdateProgress.SetContent(content) + s.wUpdateProgress.CenterOnScreen() + s.wUpdateProgress.SetFixedSize(true) + s.wUpdateProgress.SetCloseIntercept(func() { + // this is empty to lock window until result known + }) + s.wUpdateProgress.RequestFocus() + s.wUpdateProgress.Show() + + updateWindowCtx, cancel := context.WithTimeout(ctx, 15*time.Minute) + + // Initialize dot updater + updateText := dotUpdater() + + // Channel to receive the result from RPC call + resultErrCh := make(chan error, 1) + resultOkCh := make(chan struct{}, 1) + + // Start RPC call in background + go func() { + conn, err := s.getSrvClient(defaultFailTimeout) + if err != nil { + log.Infof("backend not reachable, upgrade in progress: %v", err) + close(resultOkCh) + return + } + + resp, err := conn.GetInstallerResult(updateWindowCtx, &proto.InstallerResultRequest{}) + if err != nil { + log.Infof("backend stopped responding, upgrade in progress: %v", err) + close(resultOkCh) + return + } + + if !resp.Success { + resultErrCh <- mapInstallError(resp.ErrorMsg) + return + } + + // Success + close(resultOkCh) + }() + + // Update UI with dots and wait for result + go func() { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + defer cancel() + + // allow closing update window after 10 sec + timerResetCloseInterceptor := time.NewTimer(10 * time.Second) + defer timerResetCloseInterceptor.Stop() + + for { + select { + case <-updateWindowCtx.Done(): + s.showInstallerResult(statusLabel, updateWindowCtx.Err()) + return + case err := <-resultErrCh: + s.showInstallerResult(statusLabel, err) + return + case <-resultOkCh: + log.Info("backend exited, upgrade in progress, closing all UI") + killParentUIProcess() + s.app.Quit() + return + case <-ticker.C: + statusLabel.SetText(updateText()) + case <-timerResetCloseInterceptor.C: + s.wUpdateProgress.SetCloseIntercept(nil) + } + } + }() +} + +func (s *serviceClient) showInstallerResult(statusLabel *widget.Label, err error) { + s.wUpdateProgress.SetCloseIntercept(nil) + switch { + case errors.Is(err, context.DeadlineExceeded): + log.Warn("update watcher timed out") + statusLabel.SetText("Update timed out. Please try again.") + case errors.Is(err, context.Canceled): + log.Info("update watcher canceled") + statusLabel.SetText("Update canceled.") + case err != nil: + log.Errorf("update failed: %v", err) + statusLabel.SetText("Update failed: " + err.Error()) + default: + s.wUpdateProgress.Close() + } +} + +// dotUpdater returns a closure that cycles through dots for a loading animation. +func dotUpdater() func() string { + dotCount := 0 + return func() string { + dotCount = (dotCount + 1) % 4 + return fmt.Sprintf("%s%s", "Updating", strings.Repeat(".", dotCount)) + } +} + +func mapInstallError(msg string) error { + msg = strings.ToLower(strings.TrimSpace(msg)) + + switch { + case strings.Contains(msg, "deadline exceeded"), strings.Contains(msg, "timeout"): + return context.DeadlineExceeded + case strings.Contains(msg, "canceled"), strings.Contains(msg, "cancelled"): + return context.Canceled + case msg == "": + return errors.New("unknown update error") + default: + return errors.New(msg) + } +} diff --git a/client/ui/update_notwindows.go b/client/ui/update_notwindows.go new file mode 100644 index 000000000..5766f18f7 --- /dev/null +++ b/client/ui/update_notwindows.go @@ -0,0 +1,7 @@ +//go:build !windows && !(linux && 386) + +package main + +func killParentUIProcess() { + // No-op on non-Windows platforms +} diff --git a/client/ui/update_windows.go b/client/ui/update_windows.go new file mode 100644 index 000000000..1b03936f9 --- /dev/null +++ b/client/ui/update_windows.go @@ -0,0 +1,44 @@ +//go:build windows + +package main + +import ( + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" + + nbprocess "github.com/netbirdio/netbird/client/ui/process" +) + +// killParentUIProcess finds and kills the parent systray UI process on Windows. +// This is a workaround in case the MSI installer fails to properly terminate the UI process. +// The installer should handle this via util:CloseApplication with TerminateProcess, but this +// provides an additional safety mechanism to ensure the UI is closed before the upgrade proceeds. +func killParentUIProcess() { + pid, running, err := nbprocess.IsAnotherProcessRunning() + if err != nil { + log.Warnf("failed to check for parent UI process: %v", err) + return + } + + if !running { + log.Debug("no parent UI process found to kill") + return + } + + log.Infof("killing parent UI process (PID: %d)", pid) + + // Open the process with terminate rights + handle, err := windows.OpenProcess(windows.PROCESS_TERMINATE, false, uint32(pid)) + if err != nil { + log.Warnf("failed to open parent process %d: %v", pid, err) + return + } + defer func() { + _ = windows.CloseHandle(handle) + }() + + // Terminate the process with exit code 0 + if err := windows.TerminateProcess(handle, 0); err != nil { + log.Warnf("failed to terminate parent process %d: %v", pid, err) + } +} diff --git a/go.mod b/go.mod index 8f4ec530b..23cf0f37d 100644 --- a/go.mod +++ b/go.mod @@ -8,22 +8,22 @@ require ( github.com/cloudflare/circl v1.3.3 // indirect github.com/golang/protobuf v1.5.4 github.com/google/uuid v1.6.0 - github.com/gorilla/mux v1.8.0 + github.com/gorilla/mux v1.8.1 github.com/kardianos/service v1.2.3-0.20240613133416-becf2eb62b83 github.com/onsi/ginkgo v1.16.5 github.com/onsi/gomega v1.27.6 github.com/rs/cors v1.8.0 github.com/sirupsen/logrus v1.9.3 - github.com/spf13/cobra v1.7.0 - github.com/spf13/pflag v1.0.5 + github.com/spf13/cobra v1.10.1 + github.com/spf13/pflag v1.0.9 github.com/vishvananda/netlink v1.3.1 - golang.org/x/crypto v0.45.0 - golang.org/x/sys v0.38.0 + golang.org/x/crypto v0.46.0 + golang.org/x/sys v0.39.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 - google.golang.org/grpc v1.73.0 - google.golang.org/protobuf v1.36.8 + google.golang.org/grpc v1.77.0 + google.golang.org/protobuf v1.36.10 gopkg.in/natefinch/lumberjack.v2 v2.0.0 ) @@ -41,6 +41,8 @@ require ( github.com/coder/websocket v1.8.13 github.com/coreos/go-iptables v0.7.0 github.com/creack/pty v1.1.18 + github.com/dexidp/dex v0.0.0-00010101000000-000000000000 + github.com/dexidp/dex/api/v2 v2.4.0 github.com/eko/gocache/lib/v4 v4.2.0 github.com/eko/gocache/store/go_cache/v4 v4.2.2 github.com/eko/gocache/store/redis/v4 v4.2.2 @@ -78,7 +80,7 @@ require ( github.com/pion/transport/v3 v3.0.7 github.com/pion/turn/v3 v3.0.1 github.com/pkg/sftp v1.13.9 - github.com/prometheus/client_golang v1.22.0 + github.com/prometheus/client_golang v1.23.2 github.com/quic-go/quic-go v0.49.1 github.com/redis/go-redis/v9 v9.7.3 github.com/rs/xid v1.3.0 @@ -96,11 +98,11 @@ require ( github.com/vmihailenco/msgpack/v5 v5.4.1 github.com/yusufpapurcu/wmi v1.2.4 github.com/zcalusic/sysinfo v1.1.3 - go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 - go.opentelemetry.io/otel v1.35.0 + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 + go.opentelemetry.io/otel v1.38.0 go.opentelemetry.io/otel/exporters/prometheus v0.48.0 - go.opentelemetry.io/otel/metric v1.35.0 - go.opentelemetry.io/otel/sdk/metric v1.35.0 + go.opentelemetry.io/otel/metric v1.38.0 + go.opentelemetry.io/otel/sdk/metric v1.38.0 go.uber.org/mock v0.5.0 go.uber.org/zap v1.27.0 goauthentik.io/api/v3 v3.2023051.3 @@ -108,11 +110,11 @@ require ( golang.org/x/mobile v0.0.0-20251113184115-a159579294ab golang.org/x/mod v0.30.0 golang.org/x/net v0.47.0 - golang.org/x/oauth2 v0.30.0 - golang.org/x/sync v0.18.0 - golang.org/x/term v0.37.0 - golang.org/x/time v0.12.0 - google.golang.org/api v0.177.0 + golang.org/x/oauth2 v0.34.0 + golang.org/x/sync v0.19.0 + golang.org/x/term v0.38.0 + golang.org/x/time v0.14.0 + google.golang.org/api v0.257.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.5.7 gorm.io/driver/postgres v1.5.7 @@ -122,13 +124,18 @@ require ( ) require ( - cloud.google.com/go/auth v0.3.0 // indirect - cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect - cloud.google.com/go/compute/metadata v0.6.0 // indirect - dario.cat/mergo v1.0.0 // indirect + cloud.google.com/go/auth v0.17.0 // indirect + cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect + cloud.google.com/go/compute/metadata v0.9.0 // indirect + dario.cat/mergo v1.0.1 // indirect filippo.io/edwards25519 v1.1.0 // indirect + github.com/AppsFlyer/go-sundheit v0.6.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect + github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect github.com/BurntSushi/toml v1.5.0 // indirect + github.com/Masterminds/goutils v1.1.1 // indirect + github.com/Masterminds/semver/v3 v3.3.0 // indirect + github.com/Masterminds/sprig/v3 v3.3.0 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Microsoft/hcsshim v0.12.3 // indirect github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect @@ -149,12 +156,14 @@ require ( github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 // indirect github.com/aws/smithy-go v1.22.2 // indirect + github.com/beevik/etree v1.6.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/caddyserver/zerossl v0.1.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/containerd/containerd v1.7.29 // indirect github.com/containerd/log v0.1.0 // indirect github.com/containerd/platforms v0.2.1 // indirect + github.com/coreos/go-oidc/v3 v3.14.1 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect @@ -168,26 +177,30 @@ require ( github.com/fyne-io/glfw-js v0.3.0 // indirect github.com/fyne-io/image v0.1.1 // indirect github.com/fyne-io/oksvg v0.2.0 // indirect + github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect - github.com/go-logr/logr v1.4.2 // indirect + github.com/go-jose/go-jose/v4 v4.1.3 // indirect + github.com/go-ldap/ldap/v3 v3.4.12 // indirect + github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.3.0 // indirect - github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/go-sql-driver/mysql v1.9.3 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/go-text/render v0.2.0 // indirect github.com/go-text/typesetting v0.2.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/google/btree v1.1.2 // indirect github.com/google/pprof v0.0.0-20211214055906-6f57359322fd // indirect - github.com/google/s2a-go v0.1.7 // indirect - github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect - github.com/googleapis/gax-go/v2 v2.12.3 // indirect + github.com/google/s2a-go v0.1.9 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect + github.com/googleapis/gax-go/v2 v2.15.0 // indirect + github.com/gorilla/handlers v1.5.2 // indirect github.com/hack-pad/go-indexeddb v0.3.2 // indirect github.com/hack-pad/safejs v0.1.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-uuid v1.0.3 // indirect + github.com/huandu/xstrings v1.5.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect @@ -196,18 +209,23 @@ require ( github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect + github.com/jonboulle/clockwork v0.5.0 // indirect github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 // indirect github.com/kelseyhightower/envconfig v1.4.0 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/kr/fs v0.1.0 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/libdns/libdns v0.2.2 // indirect github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect github.com/magiconair/properties v1.8.7 // indirect - github.com/mattn/go-sqlite3 v1.14.22 // indirect + github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect + github.com/mattn/go-sqlite3 v1.14.32 // indirect github.com/mdlayher/genetlink v1.3.2 // indirect github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect github.com/mholt/acmez/v2 v2.0.1 // indirect + github.com/mitchellh/copystructure v1.2.0 // indirect + github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/patternmatcher v0.6.0 // indirect github.com/moby/sys/sequential v0.5.0 // indirect @@ -230,11 +248,14 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect - github.com/prometheus/client_model v0.6.1 // indirect - github.com/prometheus/common v0.62.0 // indirect - github.com/prometheus/procfs v0.15.1 // indirect + github.com/prometheus/client_model v0.6.2 // indirect + github.com/prometheus/common v0.66.1 // indirect + github.com/prometheus/procfs v0.16.1 // indirect + github.com/russellhaering/goxmldsig v1.5.0 // indirect github.com/rymdport/portal v0.4.2 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect + github.com/shopspring/decimal v1.4.0 // indirect + github.com/spf13/cast v1.7.0 // indirect github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect github.com/stretchr/objx v0.5.2 // indirect @@ -245,17 +266,17 @@ require ( github.com/wlynxg/anet v0.0.3 // indirect github.com/yuin/goldmark v1.7.8 // indirect github.com/zeebo/blake3 v0.2.3 // indirect - go.opencensus.io v0.24.0 // indirect - go.opentelemetry.io/auto/sdk v1.1.0 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect - go.opentelemetry.io/otel/sdk v1.35.0 // indirect - go.opentelemetry.io/otel/trace v1.35.0 // indirect + go.opentelemetry.io/auto/sdk v1.2.1 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect + go.opentelemetry.io/otel/sdk v1.38.0 // indirect + go.opentelemetry.io/otel/trace v1.38.0 // indirect go.uber.org/multierr v1.11.0 // indirect + go.yaml.in/yaml/v2 v2.4.2 // indirect golang.org/x/image v0.33.0 // indirect - golang.org/x/text v0.31.0 // indirect + golang.org/x/text v0.32.0 // indirect golang.org/x/tools v0.39.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) @@ -271,3 +292,5 @@ replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-2023080111 replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 + +replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0 diff --git a/go.sum b/go.sum index f10e1e6da..354c7732e 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,14 @@ -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go/auth v0.3.0 h1:PRyzEpGfx/Z9e8+lHsbkoUVXD0gnu4MNmm7Gp8TQNIs= -cloud.google.com/go/auth v0.3.0/go.mod h1:lBv6NKTWp8E3LPzmO1TbiiRKc4drLOfHsgmlH9ogv5w= -cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4= -cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q= +cloud.google.com/go/auth v0.17.0 h1:74yCm7hCj2rUyyAocqnFzsAYXgJhrG26XCFimrc/Kz4= +cloud.google.com/go/auth v0.17.0/go.mod h1:6wv/t5/6rOPAX4fJiRjKkJCvswLwdet7G8+UGXt7nCQ= +cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= +cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= -cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I= -cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= +cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= +cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= cunicu.li/go-rosenpass v0.4.0 h1:LtPtBgFWY/9emfgC4glKLEqS0MJTylzV6+ChRhiZERw= cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJplf/M= -dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= -dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= +dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= +dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= fyne.io/fyne/v2 v2.7.0 h1:GvZSpE3X0liU/fqstInVvRsaboIVpIWQ4/sfjDGIGGQ= @@ -18,17 +17,28 @@ fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58 h1:eA5/u2XRd8OUkoMqEv3IBlF fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs= github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU= github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8= +github.com/AppsFlyer/go-sundheit v0.6.0 h1:d2hBvCjBSb2lUsEWGfPigr4MCOt04sxB+Rppl0yUMSk= +github.com/AppsFlyer/go-sundheit v0.6.0/go.mod h1:LDdBHD6tQBtmHsdW+i1GwdTt6Wqc0qazf5ZEJVTbTME= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 h1:mFRzDkZVAjdal+s7s0MwaRv9igoPqLRdzOLzw/8Xvq8= +github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU= github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg= github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= +github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI= +github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU= +github.com/Masterminds/semver/v3 v3.3.0 h1:B8LGeaivUe71a5qox1ICM/JLl0NqZSW5CHyL+hmvYS0= +github.com/Masterminds/semver/v3 v3.3.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= +github.com/Masterminds/sprig/v3 v3.3.0 h1:mQh0Yrg1XPo6vjYXgtf5OtijNAKJRNcTdOOGZe3tPhs= +github.com/Masterminds/sprig/v3 v3.3.0/go.mod h1:Zy1iXRYNqNLUolqCpL4uhk6SHUMAOSCzdgBfDb35Lz0= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/Microsoft/hcsshim v0.12.3 h1:LS9NXqXhMoqNCplK1ApmVSfB4UnVLRDWRapB6EIlxE0= github.com/Microsoft/hcsshim v0.12.3/go.mod h1:Iyl1WVpZzr+UkzjekHZbV8o5Z9ZkxNGx6CtY2Qg/JVQ= github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible h1:hqcTK6ZISdip65SR792lwYJTa/axESA0889D3UlZbLo= github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible/go.mod h1:6B1nuc1MUs6c62ODZDl7hVE5Pv7O2XGSkgg2olnq34I= +github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e h1:4dAU9FXIyQktpoUAgOJK3OTFc/xug0PCXYCqU0FgDKI= +github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/awnumar/memcall v0.4.0 h1:B7hgZYdfH6Ot1Goaz8jGne/7i8xD4taZie/PNSFZ29g= @@ -73,6 +83,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 h1:1XuUZ8mYJw9B6lzAkXhqHlJd/Xv github.com/aws/aws-sdk-go-v2/service/sts v1.33.19/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4= github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ= github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= +github.com/beevik/etree v1.6.0 h1:u8Kwy8pp9D9XeITj2Z0XtA5qqZEmtJtuXZRQi+j03eE= +github.com/beevik/etree v1.6.0/go.mod h1:bh4zJxiIr62SOf9pRzN7UUYaEDa9HEKafK25+sLc0Gc= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= @@ -87,7 +99,6 @@ github.com/caddyserver/zerossl v0.1.3 h1:onS+pxp3M8HnHpN5MMbOMyNjmTheJyWRaZYwn+Y github.com/caddyserver/zerossl v0.1.3/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= -github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= @@ -95,8 +106,6 @@ github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5P github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk= github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE= github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/containerd/containerd v1.7.29 h1:90fWABQsaN9mJhGkoVnuzEY+o1XDPbg9BTC9QTAHnuE= @@ -107,9 +116,11 @@ github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpS github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw= github.com/coreos/go-iptables v0.7.0 h1:XWM3V+MPRr5/q51NuWSgU0fqMad64Zyxs8ZUoMsamr8= github.com/coreos/go-iptables v0.7.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= +github.com/coreos/go-oidc/v3 v3.14.1 h1:9ePWwfdwC4QKRlCXsJGou56adA/owXczOzwKdOumLqk= +github.com/coreos/go-oidc/v3 v3.14.1/go.mod h1:HaZ3szPaZ0e4r6ebqvsLWlk2Tn+aejfmrfah6hnSYEU= github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA= github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc= -github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:/DS5cDX3FJdl+XaN2D7XAwFpuanTxnp52DBLZAaJKx0= @@ -117,6 +128,8 @@ github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70J github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dexidp/dex/api/v2 v2.4.0 h1:gNba7n6BKVp8X4Jp24cxYn5rIIGhM6kDOXcZoL6tr9A= +github.com/dexidp/dex/api/v2 v2.4.0/go.mod h1:/p550ADvFFh7K95VmhUD+jgm15VdaNnab9td8DHOpyI= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= @@ -133,14 +146,14 @@ github.com/eko/gocache/store/go_cache/v4 v4.2.2 h1:tAI9nl6TLoJyKG1ujF0CS0n/IgTEM github.com/eko/gocache/store/go_cache/v4 v4.2.2/go.mod h1:T9zkHokzr8K9EiC7RfMbDg6HSwaV6rv3UdcNu13SGcA= github.com/eko/gocache/store/redis/v4 v4.2.2 h1:Thw31fzGuH3WzJywsdbMivOmP550D6JS7GDHhvCJPA0= github.com/eko/gocache/store/redis/v4 v4.2.2/go.mod h1:LaTxLKx9TG/YUEybQvPMij++D7PBTIJ4+pzvk0ykz0w= -github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= -github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/felixge/fgprof v0.9.3 h1:VvyZxILNuCiUCSXtPtYmmtGvb65nqXh2QFWc0Wpf2/g= github.com/felixge/fgprof v0.9.3/go.mod h1:RdbpDgzqYVh/T9fPELJyV7EYJuHB55UTEULNun8eiPw= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= +github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fredbi/uri v1.1.1 h1:xZHJC08GZNIUhbP5ImTHnt5Ya0T8FI2VAwI/37kh2Ko= github.com/fredbi/uri v1.1.1/go.mod h1:4+DZQ5zBjEwQCDmXW5JdIjz0PUA+yJbvtBv+u+adr5o= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= @@ -159,13 +172,19 @@ github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm github.com/gin-gonic/gin v1.5.0/go.mod h1:Nd6IXA8m5kNZdNEHMBd93KT+mdY3+bewLgRvmCsR2Do= github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c= github.com/gliderlabs/ssh v0.3.8/go.mod h1:xYoytBv1sV0aL3CavoDuJIQNURXkkfPA/wxQ1pL1fAU= +github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo= +github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 h1:5BVwOaUSBTlVZowGO6VZGw2H/zl9nrd3eCZfYV+NfQA= github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71/go.mod h1:9YTyiznxEY1fVinfM7RvRcjRHbw2xLBJ3AAGIT0I4Nw= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a h1:vxnBhFDDT+xzxf1jTJKMKZw3H0swfWk9RpWbBbDK5+0= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= +github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= +github.com/go-ldap/ldap/v3 v3.4.12 h1:1b81mv7MagXZ7+1r7cLTWmyuTqVqdwbtJSjC0DAp9s4= +github.com/go-ldap/ldap/v3 v3.4.12/go.mod h1:+SPAGcTtOfmGsCb3h1RFiq4xpp4N636G75OEace8lNo= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= -github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= -github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= @@ -178,8 +197,8 @@ github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZs github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= -github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= -github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= +github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= @@ -195,11 +214,6 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= -github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -210,9 +224,7 @@ github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:x github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= @@ -220,12 +232,9 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= @@ -240,23 +249,24 @@ github.com/google/nftables v0.3.0 h1:bkyZ0cbpVeMHXOrtlFc8ISmfVqq5gPJukoYieyVmITg github.com/google/nftables v0.3.0/go.mod h1:BCp9FsrbF1Fn/Yu6CLUc9GGZFw/+hsxfluNXXmxBfRM= github.com/google/pprof v0.0.0-20211214055906-6f57359322fd h1:1FjCyPC+syAzJ5/2S8fqdZK1R22vvA0J7JZKcuOIQ7Y= github.com/google/pprof v0.0.0-20211214055906-6f57359322fd/go.mod h1:KgnwoLYCZ8IQu3XUZ8Nc/bM9CCZFOyjUNOSygVozoDg= -github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o= -github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw= -github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= +github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs= -github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0= -github.com/googleapis/gax-go/v2 v2.12.3 h1:5/zPPDvw8Q1SuXjrqrZslrqT7dL/uJT2CQii/cLCKqA= -github.com/googleapis/gax-go/v2 v2.12.3/go.mod h1:AKloxT6GtNbaLm8QTNSidHUVsHYcBHwWRvkNFJUQcS4= +github.com/googleapis/enterprise-certificate-proxy v0.3.7 h1:zrn2Ee/nWmHulBx5sAVrGgAa0f2/R35S4DJwfFaUPFQ= +github.com/googleapis/enterprise-certificate-proxy v0.3.7/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= +github.com/googleapis/gax-go/v2 v2.15.0 h1:SyjDc1mGgZU5LncH8gimWo9lW1DtIfPibOG81vgd/bo= +github.com/googleapis/gax-go/v2 v2.15.0/go.mod h1:zVVkkxAQHa1RQpg9z2AUCMnKhi0Qld9rcmyfL1OZhoc= github.com/gopacket/gopacket v1.1.1 h1:zbx9F9d6A7sWNkFKrvMBZTfGgxFoY4NgUudFVVHMfcw= github.com/gopacket/gopacket v1.1.1/go.mod h1:HavMeONEl7W9036of9LbSWoonqhH7HA1+ZRO+rMIvFs= -github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= -github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= +github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE= +github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w= +github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= +github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 h1:Fkzd8ktnpOR9h47SXHe2AYPwelXLH2GjGsjlAloiWfo= github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357/go.mod h1:w9Y7gY31krpLmrVU5ZPG9H7l9fZuRu5/3R3S3FMtVQ4= -github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 h1:YBftPWNWd4WwGqtY2yeZL2ef8rHAxPBD8KFhJpmcqms= -github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0/go.mod h1:YN5jB8ie0yfIUg6VvR9Kz84aCaG7AsGZnLjhHbUqwPg= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 h1:5ZPtiqj0JL5oKWmcsq4VMaAW5ukBEgSGXEN89zeH1Jo= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3/go.mod h1:ndYquD05frm2vACXE1nsccT4oJzjhw2arTS2cpUD1PI= github.com/hack-pad/go-indexeddb v0.3.2 h1:DTqeJJYc1usa45Q5r52t01KhvlSN02+Oq+tQbSBI91A= github.com/hack-pad/go-indexeddb v0.3.2/go.mod h1:QvfTevpDVlkfomY498LhstjwbPW6QC4VC/lxYb0Kom0= github.com/hack-pad/safejs v0.1.0 h1:qPS6vjreAqh2amUqj4WNG1zIw7qlRQJ9K10eDKMCnE8= @@ -274,6 +284,8 @@ github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/b github.com/hashicorp/go-version v1.6.0 h1:feTTfFNnjP967rlCxM/I9g701jU+RN74YKx2mOkIeek= github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= +github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= @@ -285,6 +297,18 @@ github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= +github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= +github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo= +github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM= +github.com/jcmturner/gofork v1.7.6 h1:QH0l3hzAU1tfT3rZCnW5zXl+orbkNMMRGJfdJjHVETg= +github.com/jcmturner/gofork v1.7.6/go.mod h1:1622LH6i/EZqLloHfE7IeZ0uEJwMSUyQ/nDd82IeqRo= +github.com/jcmturner/goidentity/v6 v6.0.1 h1:VKnZd2oEIMorCTsFBnJWbExfNN7yZr3EhJAxwOkZg6o= +github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg= +github.com/jcmturner/gokrb5/v8 v8.4.4 h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh687T8= +github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs= +github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY= +github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade h1:FmusiCI1wHw+XQbvL9M+1r/C3SPqKrmBaIOYwVfQoDE= github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade/go.mod h1:ZDXo8KHryOWSIqnsb/CiDq7hQUYryCgdVnxbj8tDG7o= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= @@ -295,6 +319,8 @@ github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9Y github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= +github.com/jonboulle/clockwork v0.5.0 h1:Hyh9A8u51kptdkR+cqRpT1EebBwTn1oK9YfGYbdFz6I= +github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7XN3SzBPjZF60= github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 h1:YLvr1eE6cdCqjOe972w/cYF+FjW34v27+9Vo5106B4M= github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw= @@ -309,8 +335,11 @@ github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuV github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= @@ -329,9 +358,11 @@ github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tA github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU= +github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To= github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= -github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= -github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= +github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg= @@ -344,8 +375,12 @@ github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs= github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc= +github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= +github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s= github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4= github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE= +github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ= +github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk= @@ -364,6 +399,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U= +github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= @@ -434,6 +471,7 @@ github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8= github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE= github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc= github.com/pion/turn/v4 v4.1.1/go.mod h1:2123tHk1O++vmjI5VSD0awT50NywDAq5A2NNNU4Jjs8= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/profile v1.7.0 h1:hnbDkaNWPCLMO9wGLdBFTIZvzDrDfBM2072E1S9gJkA= @@ -445,25 +483,26 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU= github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= -github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q= -github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0= -github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= -github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= -github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io= -github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= -github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= -github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= +github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= +github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= +github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= +github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= +github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs= +github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA= +github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= +github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= github.com/quic-go/quic-go v0.49.1 h1:e5JXpUyF0f2uFjckQzD8jTghZrOUK1xxDqqZhlwixo0= github.com/quic-go/quic-go v0.49.1/go.mod h1:s2wDnmCdooUQBmQfpUSTCYBl1/D4FcqbULMMkASvR6s= github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM= github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA= -github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= -github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/rs/cors v1.8.0 h1:P2KMzcFwrPoSjkF1WLRPsp3UMLyql8L4v9hQpVeK5so= github.com/rs/cors v1.8.0/go.mod h1:EBwu+T5AvHOcXwvZIkQFjUN6s8Czyqw12GL/Y0tUyRM= github.com/rs/xid v1.3.0 h1:6NjYksEUlhurdVehpc7S7dk6DAmcKv8V9gG0FsVN2U4= github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/russellhaering/goxmldsig v1.5.0 h1:AU2UkkYIUOTyZRbe08XMThaOCelArgvNfYapcmSjBNw= +github.com/russellhaering/goxmldsig v1.5.0/go.mod h1:x98CjQNFJcWfMxeOrMnMKg70lvDP6tE0nTaeUnjXDmk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/rymdport/portal v0.4.2 h1:7jKRSemwlTyVHHrTGgQg7gmNPJs88xkbKcIL3NlcmSU= github.com/rymdport/portal v0.4.2/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4= @@ -473,21 +512,26 @@ github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFt github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ= github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU= github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA= github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog= github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8= github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= -github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= -github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= -github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= -github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w= +github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s= +github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0= +github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c h1:km8GpoQut05eY3GiYWEedbTT0qnSxrCjsVbb7yKY1KE= github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c/go.mod h1:cNQ3dwVJtS5Hmnjxy6AgTPd0Inb3pW05ftPSX7NZO7Q= github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef h1:Ch6Q+AZUxDBCVqdkI8FSpFyZDtCVBc2VmejdNrm5rRQ= github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef/go.mod h1:nXTWP6+gD5+LUJ8krVhhoeHjvHTutPxMYl5SvkcnJNE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= @@ -499,7 +543,6 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= @@ -553,30 +596,28 @@ github.com/zeebo/blake3 v0.2.3 h1:TFoLXsjeXqRNFxSbk35Dk4YtszE/MQQGK10BH4ptoTg= github.com/zeebo/blake3 v0.2.3/go.mod h1:mjJjZpnsyIVtVgTOSpJ9vmRE4wgDeyt2HU3qXvvKCaQ= github.com/zeebo/pcg v1.0.1 h1:lyqfGeWiv4ahac6ttHs+I5hwtH/+1mrhlCtVNQM2kHo= github.com/zeebo/pcg v1.0.1/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4= -go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= -go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= -go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= -go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 h1:4Pp6oUg3+e/6M4C0A/3kJ2VYa++dsWVTtGgLVj5xtHg= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0/go.mod h1:vy+2G/6NvVMpwGX/NyLqcC41fxepnuKHk16E6IZUcJc= -go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ= -go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 h1:q4XOmH/0opmeuJtPsbFNivyl7bCt7yRBbeEm2sC/XtQ= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0/go.mod h1:snMWehoOh2wsEwnvvwtDyFCxVeDAODenXHtn5vzrKjo= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 h1:F7Jx+6hwnZ41NSFTO5q4LYDtJRXBf2PD0rNBkeB/lus= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0/go.mod h1:UHB22Z8QsdRDrnAtX4PntOl36ajSxcdUMt1sF7Y6E7Q= +go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= +go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU= go.opentelemetry.io/otel/exporters/prometheus v0.48.0 h1:sBQe3VNGUjY9IKWQC6z2lNqa5iGbDSxhs60ABwK4y0s= go.opentelemetry.io/otel/exporters/prometheus v0.48.0/go.mod h1:DtrbMzoZWwQHyrQmCfLam5DZbnmorsGbOtTbYHycU5o= -go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M= -go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE= -go.opentelemetry.io/otel/sdk v1.35.0 h1:iPctf8iprVySXSKJffSS79eOjl9pvxV9ZqOWT0QejKY= -go.opentelemetry.io/otel/sdk v1.35.0/go.mod h1:+ga1bZliga3DxJ3CQGg3updiaAJoNECOgJREo9KHGQg= -go.opentelemetry.io/otel/sdk/metric v1.35.0 h1:1RriWBmCKgkeHEhM7a2uMjMUfP7MsOF5JpUCaEqEI9o= -go.opentelemetry.io/otel/sdk/metric v1.35.0/go.mod h1:is6XYCUMpcKi+ZsOvfluY5YstFnhW0BidkR+gL+qN+w= -go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs= -go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc= +go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA= +go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI= +go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E= +go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg= +go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM= +go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA= +go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= +go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I= go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= @@ -587,6 +628,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= +go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= goauthentik.io/api/v3 v3.2023051.3 h1:NebAhD/TeTWNo/9X3/Uj+rM5fG1HaiLOlKTNLQv9Qq4= goauthentik.io/api/v3 v3.2023051.3/go.mod h1:nYECml4jGbp/541hj8GcylKQG1gVBsKppHy4+7G8u4U= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -600,16 +643,12 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= -golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= -golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= +golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= golang.org/x/image v0.33.0 h1:LXRZRnv1+zGd5XBUVRFmYEphyyKJjQjCRiOuAP3sZfQ= golang.org/x/image v0.33.0/go.mod h1:DD3OsTYT9chzuzTQt+zMcOlBHgfoKQb1gry8p76Y1sc= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mobile v0.0.0-20251113184115-a159579294ab h1:Iqyc+2zr7aGyLuEadIm0KRJP0Wwt+fhlXLa51Fxf1+Q= golang.org/x/mobile v0.0.0-20251113184115-a159579294ab/go.mod h1:Eq3Nh/5pFSWug2ohiudJ1iyU59SO78QFuh4qTTN++I0= @@ -624,18 +663,13 @@ golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= @@ -649,12 +683,10 @@ golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= -golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= -golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= +golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -665,9 +697,8 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= -golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -703,8 +734,8 @@ golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= -golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -717,8 +748,8 @@ golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= -golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= -golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= +golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= +golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -730,15 +761,11 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= -golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= -golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= -golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= -golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= +golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= +golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= @@ -761,42 +788,33 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvY golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= -google.golang.org/api v0.177.0 h1:8a0p/BbPa65GlqGWtUKxot4p0TV8OGOfyTjtmkXNXmk= -google.golang.org/api v0.177.0/go.mod h1:srbhue4MLjkjbkux5p3dw/ocYOSZTaIEvf7bCOnFQDw= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= +google.golang.org/api v0.257.0 h1:8Y0lzvHlZps53PEaw+G29SsQIkuKrumGWs9puiexNAA= +google.golang.org/api v0.257.0/go.mod h1:4eJrr+vbVaZSqs7vovFd1Jb/A6ml6iw2e6FBYf3GAO4= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/genproto v0.0.0-20240123012728-ef4313101c80 h1:KAeGQVN3M9nD0/bQXnr/ClcEMJ968gUXJQ9pwfSynuQ= -google.golang.org/genproto/googleapis/api v0.0.0-20250324211829-b45e905df463 h1:hE3bRWtU6uceqlh4fhrSnUyjKHMKB9KrTLLG+bc0ddM= -google.golang.org/genproto/googleapis/api v0.0.0-20250324211829-b45e905df463/go.mod h1:U90ffi8eUL9MwPcrJylN5+Mk2v3vuPDptd5yyNUiRR8= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 h1:pFyd6EwwL2TqFf8emdthzeX+gZE1ElRq3iM8pui4KBY= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= -google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= -google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok= -google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc= +google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4= +google.golang.org/genproto v0.0.0-20250603155806-513f23925822/go.mod h1:HubltRL7rMh0LfnQPkMH4NPDFEWp0jw3vixw7jEM53s= +google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8 h1:mepRgnBZa07I4TRuomDE4sTIYieg/osKmzIf4USdWS4= +google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8/go.mod h1:fDMmzKV90WSg1NbozdqrE64fkuTv6mlq2zxo9ad+3yo= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846 h1:Wgl1rcDNThT+Zn47YyCXOXyX/COgMTIdhJ717F0l4xk= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= +google.golang.org/grpc v1.77.0 h1:wVVY6/8cGA6vvffn+wWK5ToddbgdU3d8MNENr4evgXM= +google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHhrNig= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= -google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= +google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= +google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= @@ -832,5 +850,3 @@ gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY= gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 h1:qDCwdCWECGnwQSQC01Dpnp09fRHxJs9PbktotUqG+hs= gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1/go.mod h1:8hmigyCdYtw5xJGfQDJzSH5Ju8XEIDBnpyi8+O6GRt8= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/idp/dex/config.go b/idp/dex/config.go new file mode 100644 index 000000000..57f832406 --- /dev/null +++ b/idp/dex/config.go @@ -0,0 +1,301 @@ +package dex + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "log/slog" + "os" + "time" + + "golang.org/x/crypto/bcrypt" + "gopkg.in/yaml.v3" + + "github.com/dexidp/dex/server" + "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/sql" + + "github.com/netbirdio/netbird/idp/dex/web" +) + +// parseDuration parses a duration string (e.g., "6h", "24h", "168h"). +func parseDuration(s string) (time.Duration, error) { + return time.ParseDuration(s) +} + +// YAMLConfig represents the YAML configuration file format (mirrors dex's config format) +type YAMLConfig struct { + Issuer string `yaml:"issuer" json:"issuer"` + Storage Storage `yaml:"storage" json:"storage"` + Web Web `yaml:"web" json:"web"` + GRPC GRPC `yaml:"grpc" json:"grpc"` + OAuth2 OAuth2 `yaml:"oauth2" json:"oauth2"` + Expiry Expiry `yaml:"expiry" json:"expiry"` + Logger Logger `yaml:"logger" json:"logger"` + Frontend Frontend `yaml:"frontend" json:"frontend"` + + // StaticConnectors are user defined connectors specified in the config file + StaticConnectors []Connector `yaml:"connectors" json:"connectors"` + + // StaticClients cause the server to use this list of clients rather than + // querying the storage. Write operations, like creating a client, will fail. + StaticClients []storage.Client `yaml:"staticClients" json:"staticClients"` + + // If enabled, the server will maintain a list of passwords which can be used + // to identify a user. + EnablePasswordDB bool `yaml:"enablePasswordDB" json:"enablePasswordDB"` + + // StaticPasswords cause the server use this list of passwords rather than + // querying the storage. + StaticPasswords []Password `yaml:"staticPasswords" json:"staticPasswords"` +} + +// Web is the config format for the HTTP server. +type Web struct { + HTTP string `yaml:"http" json:"http"` + HTTPS string `yaml:"https" json:"https"` + AllowedOrigins []string `yaml:"allowedOrigins" json:"allowedOrigins"` + AllowedHeaders []string `yaml:"allowedHeaders" json:"allowedHeaders"` +} + +// GRPC is the config for the gRPC API. +type GRPC struct { + Addr string `yaml:"addr" json:"addr"` + TLSCert string `yaml:"tlsCert" json:"tlsCert"` + TLSKey string `yaml:"tlsKey" json:"tlsKey"` + TLSClientCA string `yaml:"tlsClientCA" json:"tlsClientCA"` +} + +// OAuth2 describes enabled OAuth2 extensions. +type OAuth2 struct { + SkipApprovalScreen bool `yaml:"skipApprovalScreen" json:"skipApprovalScreen"` + AlwaysShowLoginScreen bool `yaml:"alwaysShowLoginScreen" json:"alwaysShowLoginScreen"` + PasswordConnector string `yaml:"passwordConnector" json:"passwordConnector"` + ResponseTypes []string `yaml:"responseTypes" json:"responseTypes"` + GrantTypes []string `yaml:"grantTypes" json:"grantTypes"` +} + +// Expiry holds configuration for the validity period of components. +type Expiry struct { + SigningKeys string `yaml:"signingKeys" json:"signingKeys"` + IDTokens string `yaml:"idTokens" json:"idTokens"` + AuthRequests string `yaml:"authRequests" json:"authRequests"` + DeviceRequests string `yaml:"deviceRequests" json:"deviceRequests"` + RefreshTokens RefreshTokensExpiry `yaml:"refreshTokens" json:"refreshTokens"` +} + +// RefreshTokensExpiry holds configuration for refresh token expiry. +type RefreshTokensExpiry struct { + ReuseInterval string `yaml:"reuseInterval" json:"reuseInterval"` + ValidIfNotUsedFor string `yaml:"validIfNotUsedFor" json:"validIfNotUsedFor"` + AbsoluteLifetime string `yaml:"absoluteLifetime" json:"absoluteLifetime"` + DisableRotation bool `yaml:"disableRotation" json:"disableRotation"` +} + +// Logger holds configuration required to customize logging. +type Logger struct { + Level string `yaml:"level" json:"level"` + Format string `yaml:"format" json:"format"` +} + +// Frontend holds the server's frontend templates and assets config. +type Frontend struct { + Dir string `yaml:"dir" json:"dir"` + Theme string `yaml:"theme" json:"theme"` + Issuer string `yaml:"issuer" json:"issuer"` + LogoURL string `yaml:"logoURL" json:"logoURL"` + Extra map[string]string `yaml:"extra" json:"extra"` +} + +// Storage holds app's storage configuration. +type Storage struct { + Type string `yaml:"type" json:"type"` + Config map[string]interface{} `yaml:"config" json:"config"` +} + +// Password represents a static user configuration +type Password storage.Password + +func (p *Password) UnmarshalYAML(node *yaml.Node) error { + var data struct { + Email string `yaml:"email"` + Username string `yaml:"username"` + UserID string `yaml:"userID"` + Hash string `yaml:"hash"` + HashFromEnv string `yaml:"hashFromEnv"` + } + if err := node.Decode(&data); err != nil { + return err + } + *p = Password(storage.Password{ + Email: data.Email, + Username: data.Username, + UserID: data.UserID, + }) + if len(data.Hash) == 0 && len(data.HashFromEnv) > 0 { + data.Hash = os.Getenv(data.HashFromEnv) + } + if len(data.Hash) == 0 { + return fmt.Errorf("no password hash provided for user %s", data.Email) + } + + // If this value is a valid bcrypt, use it. + _, bcryptErr := bcrypt.Cost([]byte(data.Hash)) + if bcryptErr == nil { + p.Hash = []byte(data.Hash) + return nil + } + + // For backwards compatibility try to base64 decode this value. + hashBytes, err := base64.StdEncoding.DecodeString(data.Hash) + if err != nil { + return fmt.Errorf("malformed bcrypt hash: %v", bcryptErr) + } + if _, err := bcrypt.Cost(hashBytes); err != nil { + return fmt.Errorf("malformed bcrypt hash: %v", err) + } + p.Hash = hashBytes + return nil +} + +// Connector is a connector configuration that can unmarshal YAML dynamically. +type Connector struct { + Type string `yaml:"type" json:"type"` + Name string `yaml:"name" json:"name"` + ID string `yaml:"id" json:"id"` + Config map[string]interface{} `yaml:"config" json:"config"` +} + +// ToStorageConnector converts a Connector to storage.Connector type. +func (c *Connector) ToStorageConnector() (storage.Connector, error) { + data, err := json.Marshal(c.Config) + if err != nil { + return storage.Connector{}, fmt.Errorf("failed to marshal connector config: %v", err) + } + + return storage.Connector{ + ID: c.ID, + Type: c.Type, + Name: c.Name, + Config: data, + }, nil +} + +// StorageConfig is a configuration that can create a storage. +type StorageConfig interface { + Open(logger *slog.Logger) (storage.Storage, error) +} + +// OpenStorage opens a storage based on the config +func (s *Storage) OpenStorage(logger *slog.Logger) (storage.Storage, error) { + switch s.Type { + case "sqlite3": + file, _ := s.Config["file"].(string) + if file == "" { + return nil, fmt.Errorf("sqlite3 storage requires 'file' config") + } + return (&sql.SQLite3{File: file}).Open(logger) + default: + return nil, fmt.Errorf("unsupported storage type: %s", s.Type) + } +} + +// Validate validates the configuration +func (c *YAMLConfig) Validate() error { + if c.Issuer == "" { + return fmt.Errorf("no issuer specified in config file") + } + if c.Storage.Type == "" { + return fmt.Errorf("no storage type specified in config file") + } + if c.Web.HTTP == "" && c.Web.HTTPS == "" { + return fmt.Errorf("must supply a HTTP/HTTPS address to listen on") + } + if !c.EnablePasswordDB && len(c.StaticPasswords) != 0 { + return fmt.Errorf("cannot specify static passwords without enabling password db") + } + return nil +} + +// ToServerConfig converts YAMLConfig to dex server.Config +func (c *YAMLConfig) ToServerConfig(stor storage.Storage, logger *slog.Logger) server.Config { + cfg := server.Config{ + Issuer: c.Issuer, + Storage: stor, + Logger: logger, + SkipApprovalScreen: c.OAuth2.SkipApprovalScreen, + AllowedOrigins: c.Web.AllowedOrigins, + AllowedHeaders: c.Web.AllowedHeaders, + Web: server.WebConfig{ + Issuer: c.Frontend.Issuer, + LogoURL: c.Frontend.LogoURL, + Theme: c.Frontend.Theme, + Dir: c.Frontend.Dir, + Extra: c.Frontend.Extra, + }, + } + + // Use embedded NetBird-styled templates if no custom dir specified + if c.Frontend.Dir == "" { + cfg.Web.WebFS = web.FS() + } + + if len(c.OAuth2.ResponseTypes) > 0 { + cfg.SupportedResponseTypes = c.OAuth2.ResponseTypes + } + + // Apply expiry settings + if c.Expiry.SigningKeys != "" { + if d, err := parseDuration(c.Expiry.SigningKeys); err == nil { + cfg.RotateKeysAfter = d + } + } + if c.Expiry.IDTokens != "" { + if d, err := parseDuration(c.Expiry.IDTokens); err == nil { + cfg.IDTokensValidFor = d + } + } + if c.Expiry.AuthRequests != "" { + if d, err := parseDuration(c.Expiry.AuthRequests); err == nil { + cfg.AuthRequestsValidFor = d + } + } + if c.Expiry.DeviceRequests != "" { + if d, err := parseDuration(c.Expiry.DeviceRequests); err == nil { + cfg.DeviceRequestsValidFor = d + } + } + + return cfg +} + +// GetRefreshTokenPolicy creates a RefreshTokenPolicy from the expiry config. +// This should be called after ToServerConfig and the policy set on the config. +func (c *YAMLConfig) GetRefreshTokenPolicy(logger *slog.Logger) (*server.RefreshTokenPolicy, error) { + return server.NewRefreshTokenPolicy( + logger, + c.Expiry.RefreshTokens.DisableRotation, + c.Expiry.RefreshTokens.ValidIfNotUsedFor, + c.Expiry.RefreshTokens.AbsoluteLifetime, + c.Expiry.RefreshTokens.ReuseInterval, + ) +} + +// LoadConfig loads configuration from a YAML file +func LoadConfig(path string) (*YAMLConfig, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read config file: %w", err) + } + + var cfg YAMLConfig + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("failed to parse config file: %w", err) + } + + if err := cfg.Validate(); err != nil { + return nil, err + } + + return &cfg, nil +} diff --git a/idp/dex/provider.go b/idp/dex/provider.go new file mode 100644 index 000000000..09713a226 --- /dev/null +++ b/idp/dex/provider.go @@ -0,0 +1,934 @@ +// Package dex provides an embedded Dex OIDC identity provider. +package dex + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + "time" + + dexapi "github.com/dexidp/dex/api/v2" + "github.com/dexidp/dex/server" + "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/sql" + "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" + "golang.org/x/crypto/bcrypt" + "google.golang.org/grpc" +) + +// Config matches what management/internals/server/server.go expects +type Config struct { + Issuer string + Port int + DataDir string + DevMode bool + + // GRPCAddr is the address for the gRPC API (e.g., ":5557"). Empty disables gRPC. + GRPCAddr string +} + +// Provider wraps a Dex server +type Provider struct { + config *Config + yamlConfig *YAMLConfig + dexServer *server.Server + httpServer *http.Server + listener net.Listener + grpcServer *grpc.Server + grpcListener net.Listener + storage storage.Storage + logger *slog.Logger + mu sync.Mutex + running bool +} + +// NewProvider creates and initializes the Dex server +func NewProvider(ctx context.Context, config *Config) (*Provider, error) { + if config.Issuer == "" { + return nil, fmt.Errorf("issuer is required") + } + if config.Port <= 0 { + return nil, fmt.Errorf("invalid port") + } + if config.DataDir == "" { + return nil, fmt.Errorf("data directory is required") + } + + logger := slog.New(slog.NewTextHandler(os.Stderr, nil)) + + // Ensure data directory exists + if err := os.MkdirAll(config.DataDir, 0700); err != nil { + return nil, fmt.Errorf("failed to create data directory: %w", err) + } + + // Initialize SQLite storage + dbPath := filepath.Join(config.DataDir, "oidc.db") + sqliteConfig := &sql.SQLite3{File: dbPath} + stor, err := sqliteConfig.Open(logger) + if err != nil { + return nil, fmt.Errorf("failed to open storage: %w", err) + } + + // Ensure a local connector exists (for password authentication) + if err := ensureLocalConnector(ctx, stor); err != nil { + stor.Close() + return nil, fmt.Errorf("failed to ensure local connector: %w", err) + } + + // Ensure issuer ends with /oauth2 for proper path mounting + issuer := strings.TrimSuffix(config.Issuer, "/") + if !strings.HasSuffix(issuer, "/oauth2") { + issuer += "/oauth2" + } + + // Build refresh token policy (required to avoid nil pointer panics) + refreshPolicy, err := server.NewRefreshTokenPolicy(logger, false, "", "", "") + if err != nil { + stor.Close() + return nil, fmt.Errorf("failed to create refresh token policy: %w", err) + } + + // Build Dex server config - use Dex's types directly + dexConfig := server.Config{ + Issuer: issuer, + Storage: stor, + SkipApprovalScreen: true, + SupportedResponseTypes: []string{"code"}, + Logger: logger, + PrometheusRegistry: prometheus.NewRegistry(), + RotateKeysAfter: 6 * time.Hour, + IDTokensValidFor: 24 * time.Hour, + RefreshTokenPolicy: refreshPolicy, + Web: server.WebConfig{ + Issuer: "NetBird", + }, + } + + dexSrv, err := server.NewServer(ctx, dexConfig) + if err != nil { + stor.Close() + return nil, fmt.Errorf("failed to create dex server: %w", err) + } + + return &Provider{ + config: config, + dexServer: dexSrv, + storage: stor, + logger: logger, + }, nil +} + +// NewProviderFromYAML creates and initializes the Dex server from a YAMLConfig +func NewProviderFromYAML(ctx context.Context, yamlConfig *YAMLConfig) (*Provider, error) { + logger := slog.New(slog.NewTextHandler(os.Stderr, nil)) + + stor, err := yamlConfig.Storage.OpenStorage(logger) + if err != nil { + return nil, fmt.Errorf("failed to open storage: %w", err) + } + + if err := initializeStorage(ctx, stor, yamlConfig); err != nil { + stor.Close() + return nil, err + } + + dexConfig := buildDexConfig(yamlConfig, stor, logger) + dexConfig.RefreshTokenPolicy, err = yamlConfig.GetRefreshTokenPolicy(logger) + if err != nil { + stor.Close() + return nil, fmt.Errorf("failed to create refresh token policy: %w", err) + } + + dexSrv, err := server.NewServer(ctx, dexConfig) + if err != nil { + stor.Close() + return nil, fmt.Errorf("failed to create dex server: %w", err) + } + + return &Provider{ + config: &Config{Issuer: yamlConfig.Issuer, GRPCAddr: yamlConfig.GRPC.Addr}, + yamlConfig: yamlConfig, + dexServer: dexSrv, + storage: stor, + logger: logger, + }, nil +} + +// initializeStorage sets up connectors, passwords, and clients in storage +func initializeStorage(ctx context.Context, stor storage.Storage, cfg *YAMLConfig) error { + if cfg.EnablePasswordDB { + if err := ensureLocalConnector(ctx, stor); err != nil { + return fmt.Errorf("failed to ensure local connector: %w", err) + } + } + if err := ensureStaticPasswords(ctx, stor, cfg.StaticPasswords); err != nil { + return err + } + if err := ensureStaticClients(ctx, stor, cfg.StaticClients); err != nil { + return err + } + return ensureStaticConnectors(ctx, stor, cfg.StaticConnectors) +} + +// ensureStaticPasswords creates or updates static passwords in storage +func ensureStaticPasswords(ctx context.Context, stor storage.Storage, passwords []Password) error { + for _, pw := range passwords { + existing, err := stor.GetPassword(ctx, pw.Email) + if errors.Is(err, storage.ErrNotFound) { + if err := stor.CreatePassword(ctx, storage.Password(pw)); err != nil { + return fmt.Errorf("failed to create password for %s: %w", pw.Email, err) + } + continue + } + if err != nil { + return fmt.Errorf("failed to get password for %s: %w", pw.Email, err) + } + if string(existing.Hash) != string(pw.Hash) { + if err := stor.UpdatePassword(ctx, pw.Email, func(old storage.Password) (storage.Password, error) { + old.Hash = pw.Hash + old.Username = pw.Username + return old, nil + }); err != nil { + return fmt.Errorf("failed to update password for %s: %w", pw.Email, err) + } + } + } + return nil +} + +// ensureStaticClients creates or updates static clients in storage +func ensureStaticClients(ctx context.Context, stor storage.Storage, clients []storage.Client) error { + for _, client := range clients { + _, err := stor.GetClient(ctx, client.ID) + if errors.Is(err, storage.ErrNotFound) { + if err := stor.CreateClient(ctx, client); err != nil { + return fmt.Errorf("failed to create client %s: %w", client.ID, err) + } + continue + } + if err != nil { + return fmt.Errorf("failed to get client %s: %w", client.ID, err) + } + if err := stor.UpdateClient(ctx, client.ID, func(old storage.Client) (storage.Client, error) { + old.RedirectURIs = client.RedirectURIs + old.Name = client.Name + old.Public = client.Public + return old, nil + }); err != nil { + return fmt.Errorf("failed to update client %s: %w", client.ID, err) + } + } + return nil +} + +// ensureStaticConnectors creates or updates static connectors in storage +func ensureStaticConnectors(ctx context.Context, stor storage.Storage, connectors []Connector) error { + for _, conn := range connectors { + storConn, err := conn.ToStorageConnector() + if err != nil { + return fmt.Errorf("failed to convert connector %s: %w", conn.ID, err) + } + _, err = stor.GetConnector(ctx, conn.ID) + if errors.Is(err, storage.ErrNotFound) { + if err := stor.CreateConnector(ctx, storConn); err != nil { + return fmt.Errorf("failed to create connector %s: %w", conn.ID, err) + } + continue + } + if err != nil { + return fmt.Errorf("failed to get connector %s: %w", conn.ID, err) + } + if err := stor.UpdateConnector(ctx, conn.ID, func(old storage.Connector) (storage.Connector, error) { + old.Name = storConn.Name + old.Config = storConn.Config + return old, nil + }); err != nil { + return fmt.Errorf("failed to update connector %s: %w", conn.ID, err) + } + } + return nil +} + +// buildDexConfig creates a server.Config with defaults applied +func buildDexConfig(yamlConfig *YAMLConfig, stor storage.Storage, logger *slog.Logger) server.Config { + cfg := yamlConfig.ToServerConfig(stor, logger) + cfg.PrometheusRegistry = prometheus.NewRegistry() + if cfg.RotateKeysAfter == 0 { + cfg.RotateKeysAfter = 24 * 30 * time.Hour + } + if cfg.IDTokensValidFor == 0 { + cfg.IDTokensValidFor = 24 * time.Hour + } + if cfg.Web.Issuer == "" { + cfg.Web.Issuer = "NetBird" + } + if len(cfg.SupportedResponseTypes) == 0 { + cfg.SupportedResponseTypes = []string{"code"} + } + return cfg +} + +// Start starts the HTTP server and optionally the gRPC API server +func (p *Provider) Start(_ context.Context) error { + p.mu.Lock() + defer p.mu.Unlock() + + if p.running { + return fmt.Errorf("already running") + } + + // Determine listen address from config + var addr string + if p.yamlConfig != nil { + addr = p.yamlConfig.Web.HTTP + if addr == "" { + addr = p.yamlConfig.Web.HTTPS + } + } else if p.config != nil && p.config.Port > 0 { + addr = fmt.Sprintf(":%d", p.config.Port) + } + if addr == "" { + return fmt.Errorf("no listen address configured") + } + + listener, err := net.Listen("tcp", addr) + if err != nil { + return fmt.Errorf("failed to listen on %s: %w", addr, err) + } + p.listener = listener + + // Mount Dex at /oauth2/ path for reverse proxy compatibility + // Don't strip the prefix - Dex's issuer includes /oauth2 so it expects the full path + mux := http.NewServeMux() + mux.Handle("/oauth2/", p.dexServer) + + p.httpServer = &http.Server{Handler: mux} + p.running = true + + go func() { + if err := p.httpServer.Serve(listener); err != nil && err != http.ErrServerClosed { + p.logger.Error("http server error", "error", err) + } + }() + + // Start gRPC API server if configured + if p.config.GRPCAddr != "" { + if err := p.startGRPCServer(); err != nil { + // Clean up HTTP server on failure + _ = p.httpServer.Close() + _ = p.listener.Close() + return fmt.Errorf("failed to start gRPC server: %w", err) + } + } + + p.logger.Info("HTTP server started", "addr", addr) + return nil +} + +// startGRPCServer starts the gRPC API server using Dex's built-in API +func (p *Provider) startGRPCServer() error { + grpcListener, err := net.Listen("tcp", p.config.GRPCAddr) + if err != nil { + return fmt.Errorf("failed to listen on %s: %w", p.config.GRPCAddr, err) + } + p.grpcListener = grpcListener + + p.grpcServer = grpc.NewServer() + // Use Dex's built-in API server implementation + // server.NewAPI(storage, logger, version, dexServer) + dexapi.RegisterDexServer(p.grpcServer, server.NewAPI(p.storage, p.logger, "netbird-dex", p.dexServer)) + + go func() { + if err := p.grpcServer.Serve(grpcListener); err != nil { + p.logger.Error("grpc server error", "error", err) + } + }() + + p.logger.Info("gRPC API server started", "addr", p.config.GRPCAddr) + return nil +} + +// Stop gracefully shuts down +func (p *Provider) Stop(ctx context.Context) error { + p.mu.Lock() + defer p.mu.Unlock() + + if !p.running { + return nil + } + + var errs []error + + // Stop gRPC server first + if p.grpcServer != nil { + p.grpcServer.GracefulStop() + p.grpcServer = nil + } + if p.grpcListener != nil { + p.grpcListener.Close() + p.grpcListener = nil + } + + if p.httpServer != nil { + if err := p.httpServer.Shutdown(ctx); err != nil { + errs = append(errs, err) + } + } + + // Explicitly close listener as fallback (Shutdown should do this, but be safe) + if p.listener != nil { + if err := p.listener.Close(); err != nil { + // Ignore "use of closed network connection" - expected after Shutdown + if !strings.Contains(err.Error(), "use of closed") { + errs = append(errs, err) + } + } + p.listener = nil + } + + if p.storage != nil { + if err := p.storage.Close(); err != nil { + errs = append(errs, err) + } + } + + p.httpServer = nil + p.running = false + + if len(errs) > 0 { + return fmt.Errorf("shutdown errors: %v", errs) + } + return nil +} + +// EnsureDefaultClients creates dashboard and CLI OAuth clients +// Uses Dex's storage.Client directly - no custom wrappers +func (p *Provider) EnsureDefaultClients(ctx context.Context, dashboardURIs, cliURIs []string) error { + clients := []storage.Client{ + { + ID: "netbird-dashboard", + Name: "NetBird Dashboard", + RedirectURIs: dashboardURIs, + Public: true, + }, + { + ID: "netbird-cli", + Name: "NetBird CLI", + RedirectURIs: cliURIs, + Public: true, + }, + } + + for _, client := range clients { + _, err := p.storage.GetClient(ctx, client.ID) + if err == storage.ErrNotFound { + if err := p.storage.CreateClient(ctx, client); err != nil { + return fmt.Errorf("failed to create client %s: %w", client.ID, err) + } + continue + } + if err != nil { + return fmt.Errorf("failed to get client %s: %w", client.ID, err) + } + // Update if exists + if err := p.storage.UpdateClient(ctx, client.ID, func(old storage.Client) (storage.Client, error) { + old.RedirectURIs = client.RedirectURIs + return old, nil + }); err != nil { + return fmt.Errorf("failed to update client %s: %w", client.ID, err) + } + } + + p.logger.Info("default OIDC clients ensured") + return nil +} + +// Storage returns the underlying Dex storage for direct access +// Users can use storage.Client, storage.Password, storage.Connector directly +func (p *Provider) Storage() storage.Storage { + return p.storage +} + +// Handler returns the Dex server as an http.Handler for embedding in another server. +// The handler expects requests with path prefix "/oauth2/". +func (p *Provider) Handler() http.Handler { + return p.dexServer +} + +// CreateUser creates a new user with the given email, username, and password. +// Returns the encoded user ID in Dex's format (base64-encoded protobuf with connector ID). +func (p *Provider) CreateUser(ctx context.Context, email, username, password string) (string, error) { + hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return "", fmt.Errorf("failed to hash password: %w", err) + } + + userID := uuid.New().String() + err = p.storage.CreatePassword(ctx, storage.Password{ + Email: email, + Username: username, + UserID: userID, + Hash: hash, + }) + if err != nil { + return "", err + } + + // Encode the user ID in Dex's format: base64(protobuf{user_id, connector_id}) + // This matches the format Dex uses in JWT tokens + encodedID := EncodeDexUserID(userID, "local") + return encodedID, nil +} + +// EncodeDexUserID encodes user ID and connector ID into Dex's base64-encoded protobuf format. +// Dex uses this format for the 'sub' claim in JWT tokens. +// Format: base64(protobuf message with field 1 = user_id, field 2 = connector_id) +func EncodeDexUserID(userID, connectorID string) string { + // Manually encode protobuf: field 1 (user_id) and field 2 (connector_id) + // Wire type 2 (length-delimited) for strings + var buf []byte + + // Field 1: user_id (tag = 0x0a = field 1, wire type 2) + buf = append(buf, 0x0a) + buf = append(buf, byte(len(userID))) + buf = append(buf, []byte(userID)...) + + // Field 2: connector_id (tag = 0x12 = field 2, wire type 2) + buf = append(buf, 0x12) + buf = append(buf, byte(len(connectorID))) + buf = append(buf, []byte(connectorID)...) + + return base64.RawStdEncoding.EncodeToString(buf) +} + +// DecodeDexUserID decodes Dex's base64-encoded user ID back to the raw user ID and connector ID. +func DecodeDexUserID(encodedID string) (userID, connectorID string, err error) { + // Try RawStdEncoding first, then StdEncoding (with padding) + buf, err := base64.RawStdEncoding.DecodeString(encodedID) + if err != nil { + buf, err = base64.StdEncoding.DecodeString(encodedID) + if err != nil { + return "", "", fmt.Errorf("failed to decode base64: %w", err) + } + } + + // Parse protobuf manually + i := 0 + for i < len(buf) { + if i >= len(buf) { + break + } + tag := buf[i] + i++ + + fieldNum := tag >> 3 + wireType := tag & 0x07 + + if wireType != 2 { // We only expect length-delimited strings + return "", "", fmt.Errorf("unexpected wire type %d", wireType) + } + + if i >= len(buf) { + return "", "", fmt.Errorf("truncated message") + } + length := int(buf[i]) + i++ + + if i+length > len(buf) { + return "", "", fmt.Errorf("truncated string field") + } + value := string(buf[i : i+length]) + i += length + + switch fieldNum { + case 1: + userID = value + case 2: + connectorID = value + } + } + + return userID, connectorID, nil +} + +// GetUser returns a user by email +func (p *Provider) GetUser(ctx context.Context, email string) (storage.Password, error) { + return p.storage.GetPassword(ctx, email) +} + +// GetUserByID returns a user by user ID. +// The userID can be either an encoded Dex ID (base64 protobuf) or a raw UUID. +// Note: This requires iterating through all users since dex storage doesn't index by userID. +func (p *Provider) GetUserByID(ctx context.Context, userID string) (storage.Password, error) { + // Try to decode the user ID in case it's encoded + rawUserID, _, err := DecodeDexUserID(userID) + if err != nil { + // If decoding fails, assume it's already a raw UUID + rawUserID = userID + } + + users, err := p.storage.ListPasswords(ctx) + if err != nil { + return storage.Password{}, fmt.Errorf("failed to list users: %w", err) + } + for _, user := range users { + if user.UserID == rawUserID { + return user, nil + } + } + return storage.Password{}, storage.ErrNotFound +} + +// DeleteUser removes a user by email +func (p *Provider) DeleteUser(ctx context.Context, email string) error { + return p.storage.DeletePassword(ctx, email) +} + +// ListUsers returns all users +func (p *Provider) ListUsers(ctx context.Context) ([]storage.Password, error) { + return p.storage.ListPasswords(ctx) +} + +// ensureLocalConnector creates a local (password) connector if none exists +func ensureLocalConnector(ctx context.Context, stor storage.Storage) error { + connectors, err := stor.ListConnectors(ctx) + if err != nil { + return fmt.Errorf("failed to list connectors: %w", err) + } + + // If any connector exists, we're good + if len(connectors) > 0 { + return nil + } + + // Create a local connector for password authentication + localConnector := storage.Connector{ + ID: "local", + Type: "local", + Name: "Email", + } + + if err := stor.CreateConnector(ctx, localConnector); err != nil { + return fmt.Errorf("failed to create local connector: %w", err) + } + + return nil +} + +// ConnectorConfig represents the configuration for an identity provider connector +type ConnectorConfig struct { + // ID is the unique identifier for the connector + ID string + // Name is a human-readable name for the connector + Name string + // Type is the connector type (oidc, google, microsoft) + Type string + // Issuer is the OIDC issuer URL (for OIDC-based connectors) + Issuer string + // ClientID is the OAuth2 client ID + ClientID string + // ClientSecret is the OAuth2 client secret + ClientSecret string + // RedirectURI is the OAuth2 redirect URI + RedirectURI string +} + +// CreateConnector creates a new connector in Dex storage. +// It maps the connector config to the appropriate Dex connector type and configuration. +func (p *Provider) CreateConnector(ctx context.Context, cfg *ConnectorConfig) (*ConnectorConfig, error) { + // Fill in the redirect URI if not provided + if cfg.RedirectURI == "" { + cfg.RedirectURI = p.GetRedirectURI() + } + + storageConn, err := p.buildStorageConnector(cfg) + if err != nil { + return nil, fmt.Errorf("failed to build connector: %w", err) + } + + if err := p.storage.CreateConnector(ctx, storageConn); err != nil { + return nil, fmt.Errorf("failed to create connector: %w", err) + } + + p.logger.Info("connector created", "id", cfg.ID, "type", cfg.Type) + return cfg, nil +} + +// GetConnector retrieves a connector by ID from Dex storage. +func (p *Provider) GetConnector(ctx context.Context, id string) (*ConnectorConfig, error) { + conn, err := p.storage.GetConnector(ctx, id) + if err != nil { + if err == storage.ErrNotFound { + return nil, err + } + return nil, fmt.Errorf("failed to get connector: %w", err) + } + + return p.parseStorageConnector(conn) +} + +// ListConnectors returns all connectors from Dex storage (excluding the local connector). +func (p *Provider) ListConnectors(ctx context.Context) ([]*ConnectorConfig, error) { + connectors, err := p.storage.ListConnectors(ctx) + if err != nil { + return nil, fmt.Errorf("failed to list connectors: %w", err) + } + + result := make([]*ConnectorConfig, 0, len(connectors)) + for _, conn := range connectors { + // Skip the local password connector + if conn.ID == "local" && conn.Type == "local" { + continue + } + + cfg, err := p.parseStorageConnector(conn) + if err != nil { + p.logger.Warn("failed to parse connector", "id", conn.ID, "error", err) + continue + } + result = append(result, cfg) + } + + return result, nil +} + +// UpdateConnector updates an existing connector in Dex storage. +func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) error { + storageConn, err := p.buildStorageConnector(cfg) + if err != nil { + return fmt.Errorf("failed to build connector: %w", err) + } + + if err := p.storage.UpdateConnector(ctx, cfg.ID, func(old storage.Connector) (storage.Connector, error) { + return storageConn, nil + }); err != nil { + return fmt.Errorf("failed to update connector: %w", err) + } + + p.logger.Info("connector updated", "id", cfg.ID, "type", cfg.Type) + return nil +} + +// DeleteConnector removes a connector from Dex storage. +func (p *Provider) DeleteConnector(ctx context.Context, id string) error { + // Prevent deletion of the local connector + if id == "local" { + return fmt.Errorf("cannot delete the local password connector") + } + + if err := p.storage.DeleteConnector(ctx, id); err != nil { + return fmt.Errorf("failed to delete connector: %w", err) + } + + p.logger.Info("connector deleted", "id", id) + return nil +} + +// buildStorageConnector creates a storage.Connector from ConnectorConfig. +// It handles the type-specific configuration for each connector type. +func (p *Provider) buildStorageConnector(cfg *ConnectorConfig) (storage.Connector, error) { + redirectURI := p.resolveRedirectURI(cfg.RedirectURI) + + var dexType string + var configData []byte + var err error + + switch cfg.Type { + case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak": + dexType = "oidc" + configData, err = buildOIDCConnectorConfig(cfg, redirectURI) + case "google": + dexType = "google" + configData, err = buildOAuth2ConnectorConfig(cfg, redirectURI) + case "microsoft": + dexType = "microsoft" + configData, err = buildOAuth2ConnectorConfig(cfg, redirectURI) + default: + return storage.Connector{}, fmt.Errorf("unsupported connector type: %s", cfg.Type) + } + if err != nil { + return storage.Connector{}, err + } + + return storage.Connector{ID: cfg.ID, Type: dexType, Name: cfg.Name, Config: configData}, nil +} + +// resolveRedirectURI returns the redirect URI, using a default if not provided +func (p *Provider) resolveRedirectURI(redirectURI string) string { + if redirectURI != "" || p.config == nil { + return redirectURI + } + issuer := strings.TrimSuffix(p.config.Issuer, "/") + if !strings.HasSuffix(issuer, "/oauth2") { + issuer += "/oauth2" + } + return issuer + "/callback" +} + +// buildOIDCConnectorConfig creates config for OIDC-based connectors +func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, error) { + oidcConfig := map[string]interface{}{ + "issuer": cfg.Issuer, + "clientID": cfg.ClientID, + "clientSecret": cfg.ClientSecret, + "redirectURI": redirectURI, + "scopes": []string{"openid", "profile", "email"}, + } + switch cfg.Type { + case "zitadel": + oidcConfig["getUserInfo"] = true + case "entra": + oidcConfig["insecureSkipEmailVerified"] = true + oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"} + case "okta": + oidcConfig["insecureSkipEmailVerified"] = true + } + return encodeConnectorConfig(oidcConfig) +} + +// buildOAuth2ConnectorConfig creates config for OAuth2 connectors (google, microsoft) +func buildOAuth2ConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, error) { + return encodeConnectorConfig(map[string]interface{}{ + "clientID": cfg.ClientID, + "clientSecret": cfg.ClientSecret, + "redirectURI": redirectURI, + }) +} + +// parseStorageConnector converts a storage.Connector back to ConnectorConfig. +// It infers the original identity provider type from the Dex connector type and ID. +func (p *Provider) parseStorageConnector(conn storage.Connector) (*ConnectorConfig, error) { + cfg := &ConnectorConfig{ + ID: conn.ID, + Name: conn.Name, + } + + if len(conn.Config) == 0 { + cfg.Type = conn.Type + return cfg, nil + } + + var configMap map[string]interface{} + if err := decodeConnectorConfig(conn.Config, &configMap); err != nil { + return nil, fmt.Errorf("failed to parse connector config: %w", err) + } + + // Extract common fields + if v, ok := configMap["clientID"].(string); ok { + cfg.ClientID = v + } + if v, ok := configMap["clientSecret"].(string); ok { + cfg.ClientSecret = v + } + if v, ok := configMap["redirectURI"].(string); ok { + cfg.RedirectURI = v + } + if v, ok := configMap["issuer"].(string); ok { + cfg.Issuer = v + } + + // Infer the original identity provider type from Dex connector type and ID + cfg.Type = inferIdentityProviderType(conn.Type, conn.ID, configMap) + + return cfg, nil +} + +// inferIdentityProviderType determines the original identity provider type +// based on the Dex connector type, connector ID, and configuration. +func inferIdentityProviderType(dexType, connectorID string, _ map[string]interface{}) string { + if dexType != "oidc" { + return dexType + } + return inferOIDCProviderType(connectorID) +} + +// inferOIDCProviderType infers the specific OIDC provider from connector ID +func inferOIDCProviderType(connectorID string) string { + connectorIDLower := strings.ToLower(connectorID) + for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak"} { + if strings.Contains(connectorIDLower, provider) { + return provider + } + } + return "oidc" +} + +// encodeConnectorConfig serializes connector config to JSON bytes. +func encodeConnectorConfig(config map[string]interface{}) ([]byte, error) { + return json.Marshal(config) +} + +// decodeConnectorConfig deserializes connector config from JSON bytes. +func decodeConnectorConfig(data []byte, v interface{}) error { + return json.Unmarshal(data, v) +} + +// GetRedirectURI returns the default redirect URI for connectors. +func (p *Provider) GetRedirectURI() string { + if p.config == nil { + return "" + } + issuer := strings.TrimSuffix(p.config.Issuer, "/") + if !strings.HasSuffix(issuer, "/oauth2") { + issuer += "/oauth2" + } + return issuer + "/callback" +} + +// GetIssuer returns the OIDC issuer URL. +func (p *Provider) GetIssuer() string { + if p.config == nil { + return "" + } + issuer := strings.TrimSuffix(p.config.Issuer, "/") + if !strings.HasSuffix(issuer, "/oauth2") { + issuer += "/oauth2" + } + return issuer +} + +// GetKeysLocation returns the JWKS endpoint URL for token validation. +func (p *Provider) GetKeysLocation() string { + issuer := p.GetIssuer() + if issuer == "" { + return "" + } + return issuer + "/keys" +} + +// GetTokenEndpoint returns the OAuth2 token endpoint URL. +func (p *Provider) GetTokenEndpoint() string { + issuer := p.GetIssuer() + if issuer == "" { + return "" + } + return issuer + "/token" +} + +// GetDeviceAuthEndpoint returns the OAuth2 device authorization endpoint URL. +func (p *Provider) GetDeviceAuthEndpoint() string { + issuer := p.GetIssuer() + if issuer == "" { + return "" + } + return issuer + "/device/code" +} + +// GetAuthorizationEndpoint returns the OAuth2 authorization endpoint URL. +func (p *Provider) GetAuthorizationEndpoint() string { + issuer := p.GetIssuer() + if issuer == "" { + return "" + } + return issuer + "/auth" +} diff --git a/idp/dex/provider_test.go b/idp/dex/provider_test.go new file mode 100644 index 000000000..bc34e592f --- /dev/null +++ b/idp/dex/provider_test.go @@ -0,0 +1,197 @@ +package dex + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUserCreationFlow(t *testing.T) { + ctx := context.Background() + + // Create a temporary directory for the test + tmpDir, err := os.MkdirTemp("", "dex-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + // Create provider with minimal config + config := &Config{ + Issuer: "http://localhost:5556/dex", + Port: 5556, + DataDir: tmpDir, + } + + provider, err := NewProvider(ctx, config) + require.NoError(t, err) + defer func() { _ = provider.Stop(ctx) }() + + // Test user data + email := "test@example.com" + username := "testuser" + password := "testpassword123" + + // Create the user + encodedID, err := provider.CreateUser(ctx, email, username, password) + require.NoError(t, err) + require.NotEmpty(t, encodedID) + + t.Logf("Created user with encoded ID: %s", encodedID) + + // Verify the encoded ID can be decoded + rawUserID, connectorID, err := DecodeDexUserID(encodedID) + require.NoError(t, err) + assert.NotEmpty(t, rawUserID) + assert.Equal(t, "local", connectorID) + + t.Logf("Decoded: rawUserID=%s, connectorID=%s", rawUserID, connectorID) + + // Verify we can look up the user by encoded ID + user, err := provider.GetUserByID(ctx, encodedID) + require.NoError(t, err) + assert.Equal(t, email, user.Email) + assert.Equal(t, username, user.Username) + assert.Equal(t, rawUserID, user.UserID) + + // Verify we can also look up by raw UUID (backwards compatibility) + user2, err := provider.GetUserByID(ctx, rawUserID) + require.NoError(t, err) + assert.Equal(t, email, user2.Email) + + // Verify we can look up by email + user3, err := provider.GetUser(ctx, email) + require.NoError(t, err) + assert.Equal(t, rawUserID, user3.UserID) + + // Verify encoding produces consistent format + reEncodedID := EncodeDexUserID(rawUserID, "local") + assert.Equal(t, encodedID, reEncodedID) +} + +func TestDecodeDexUserID(t *testing.T) { + tests := []struct { + name string + encodedID string + wantUserID string + wantConnID string + wantErr bool + }{ + { + name: "valid encoded ID", + encodedID: "CiQ3YWFkOGMwNS0zMjg3LTQ3M2YtYjQyYS0zNjU1MDRiZjI1ZTcSBWxvY2Fs", + wantUserID: "7aad8c05-3287-473f-b42a-365504bf25e7", + wantConnID: "local", + wantErr: false, + }, + { + name: "invalid base64", + encodedID: "not-valid-base64!!!", + wantUserID: "", + wantConnID: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + userID, connID, err := DecodeDexUserID(tt.encodedID) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantUserID, userID) + assert.Equal(t, tt.wantConnID, connID) + }) + } +} + +func TestEncodeDexUserID(t *testing.T) { + userID := "7aad8c05-3287-473f-b42a-365504bf25e7" + connectorID := "local" + + encoded := EncodeDexUserID(userID, connectorID) + assert.NotEmpty(t, encoded) + + // Verify round-trip + decodedUserID, decodedConnID, err := DecodeDexUserID(encoded) + require.NoError(t, err) + assert.Equal(t, userID, decodedUserID) + assert.Equal(t, connectorID, decodedConnID) +} + +func TestEncodeDexUserID_MatchesDexFormat(t *testing.T) { + // This is an actual ID from Dex - verify our encoding matches + knownEncodedID := "CiQ3YWFkOGMwNS0zMjg3LTQ3M2YtYjQyYS0zNjU1MDRiZjI1ZTcSBWxvY2Fs" + knownUserID := "7aad8c05-3287-473f-b42a-365504bf25e7" + knownConnectorID := "local" + + // Decode the known ID + userID, connID, err := DecodeDexUserID(knownEncodedID) + require.NoError(t, err) + assert.Equal(t, knownUserID, userID) + assert.Equal(t, knownConnectorID, connID) + + // Re-encode and verify it matches + reEncoded := EncodeDexUserID(knownUserID, knownConnectorID) + assert.Equal(t, knownEncodedID, reEncoded) +} + +func TestCreateUserInTempDB(t *testing.T) { + ctx := context.Background() + + // Create temp directory + tmpDir, err := os.MkdirTemp("", "dex-create-user-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + // Create YAML config for the test + yamlContent := ` +issuer: http://localhost:5556/dex +storage: + type: sqlite3 + config: + file: ` + filepath.Join(tmpDir, "dex.db") + ` +web: + http: 127.0.0.1:5556 +enablePasswordDB: true +` + configPath := filepath.Join(tmpDir, "config.yaml") + err = os.WriteFile(configPath, []byte(yamlContent), 0644) + require.NoError(t, err) + + // Load config and create provider + yamlConfig, err := LoadConfig(configPath) + require.NoError(t, err) + + provider, err := NewProviderFromYAML(ctx, yamlConfig) + require.NoError(t, err) + defer func() { _ = provider.Stop(ctx) }() + + // Create user + email := "newuser@example.com" + username := "newuser" + password := "securepassword123" + + encodedID, err := provider.CreateUser(ctx, email, username, password) + require.NoError(t, err) + + t.Logf("Created user: email=%s, encodedID=%s", email, encodedID) + + // Verify lookup works with encoded ID + user, err := provider.GetUserByID(ctx, encodedID) + require.NoError(t, err) + assert.Equal(t, email, user.Email) + assert.Equal(t, username, user.Username) + + // Decode and verify format + rawID, connID, err := DecodeDexUserID(encodedID) + require.NoError(t, err) + assert.Equal(t, "local", connID) + assert.Equal(t, rawID, user.UserID) + + t.Logf("User lookup successful: rawID=%s, connectorID=%s", rawID, connID) +} diff --git a/idp/dex/web/robots.txt b/idp/dex/web/robots.txt new file mode 100755 index 000000000..77470cb39 --- /dev/null +++ b/idp/dex/web/robots.txt @@ -0,0 +1,2 @@ +User-agent: * +Disallow: / \ No newline at end of file diff --git a/idp/dex/web/static/main.css b/idp/dex/web/static/main.css new file mode 100755 index 000000000..39302c4c1 --- /dev/null +++ b/idp/dex/web/static/main.css @@ -0,0 +1 @@ +/* NetBird DEX Static CSS - main styles are inline in header.html */ \ No newline at end of file diff --git a/idp/dex/web/templates/approval.html b/idp/dex/web/templates/approval.html new file mode 100755 index 000000000..c84c3b3a0 --- /dev/null +++ b/idp/dex/web/templates/approval.html @@ -0,0 +1,26 @@ +{{ template "header.html" . }} + +
+

Grant Access

+

{{ .Client }} wants to access your account

+ +
+ + + +
+ +
+ +
+ + + +
+
+ +{{ template "footer.html" . }} \ No newline at end of file diff --git a/idp/dex/web/templates/device.html b/idp/dex/web/templates/device.html new file mode 100755 index 000000000..61faa6d53 --- /dev/null +++ b/idp/dex/web/templates/device.html @@ -0,0 +1,34 @@ +{{ template "header.html" . }} + +
+

Device Login

+

Enter the code shown on your device

+ +
+ {{ if .Invalid }} +
+ Invalid user code. +
+ {{ end }} + +
+ + +
+ + +
+
+ +{{ template "footer.html" . }} \ No newline at end of file diff --git a/idp/dex/web/templates/device_success.html b/idp/dex/web/templates/device_success.html new file mode 100755 index 000000000..af1d02031 --- /dev/null +++ b/idp/dex/web/templates/device_success.html @@ -0,0 +1,16 @@ +{{ template "header.html" . }} + +
+
+ + + + +
+

Device Authorized

+

+ Your device has been successfully authorized. You can close this window. +

+
+ +{{ template "footer.html" . }} \ No newline at end of file diff --git a/idp/dex/web/templates/error.html b/idp/dex/web/templates/error.html new file mode 100755 index 000000000..5dc2d190f --- /dev/null +++ b/idp/dex/web/templates/error.html @@ -0,0 +1,16 @@ +{{ template "header.html" . }} + +
+
+ + + + +
+

{{ .ErrType }}

+
+ {{ .ErrMsg }} +
+
+ +{{ template "footer.html" . }} \ No newline at end of file diff --git a/idp/dex/web/templates/footer.html b/idp/dex/web/templates/footer.html new file mode 100755 index 000000000..17c7245b6 --- /dev/null +++ b/idp/dex/web/templates/footer.html @@ -0,0 +1,3 @@ + + + \ No newline at end of file diff --git a/idp/dex/web/templates/header.html b/idp/dex/web/templates/header.html new file mode 100755 index 000000000..5759ee321 --- /dev/null +++ b/idp/dex/web/templates/header.html @@ -0,0 +1,70 @@ + + + + + + {{ issuer }} + + + + + +
+ \ No newline at end of file diff --git a/idp/dex/web/templates/login.html b/idp/dex/web/templates/login.html new file mode 100755 index 000000000..681532d86 --- /dev/null +++ b/idp/dex/web/templates/login.html @@ -0,0 +1,56 @@ +{{ template "header.html" . }} + +
+

Sign in

+

Choose your login method

+ + {{/* First pass: render Email/Local connectors at the top */}} + {{ range $c := .Connectors }} + {{- $nameLower := lower $c.Name -}} + {{- $idLower := lower $c.ID -}} + {{- if or (contains "email" $nameLower) (contains "email" $idLower) (contains "local" $nameLower) (contains "local" $idLower) -}} + + + Continue with {{ $c.Name }} + + {{- end -}} + {{ end }} + + {{/* Second pass: render all other connectors */}} + {{ range $c := .Connectors }} + {{- $nameLower := lower $c.Name -}} + {{- $idLower := lower $c.ID -}} + {{- if not (or (contains "email" $nameLower) (contains "email" $idLower) (contains "local" $nameLower) (contains "local" $idLower)) -}} + + {{- $iconClass := "nb-icon-default" -}} + {{- if or (contains "google" $nameLower) (contains "google" $idLower) -}} + {{- $iconClass = "nb-icon-google" -}} + {{- else if or (contains "github" $nameLower) (contains "github" $idLower) -}} + {{- $iconClass = "nb-icon-github" -}} + {{- else if or (contains "entra" $nameLower) (contains "entra" $idLower) -}} + {{- $iconClass = "nb-icon-entra" -}} + {{- else if or (contains "azure" $nameLower) (contains "azure" $idLower) -}} + {{- $iconClass = "nb-icon-azure" -}} + {{- else if or (contains "microsoft" $nameLower) (contains "microsoft" $idLower) -}} + {{- $iconClass = "nb-icon-microsoft" -}} + {{- else if or (contains "okta" $nameLower) (contains "okta" $idLower) -}} + {{- $iconClass = "nb-icon-okta" -}} + {{- else if or (contains "jumpcloud" $nameLower) (contains "jumpcloud" $idLower) -}} + {{- $iconClass = "nb-icon-jumpcloud" -}} + {{- else if or (contains "pocket" $nameLower) (contains "pocket" $idLower) -}} + {{- $iconClass = "nb-icon-pocketid" -}} + {{- else if or (contains "zitadel" $nameLower) (contains "zitadel" $idLower) -}} + {{- $iconClass = "nb-icon-zitadel" -}} + {{- else if or (contains "authentik" $nameLower) (contains "authentik" $idLower) -}} + {{- $iconClass = "nb-icon-authentik" -}} + {{- else if or (contains "keycloak" $nameLower) (contains "keycloak" $idLower) -}} + {{- $iconClass = "nb-icon-keycloak" -}} + {{- end -}} + + Continue with {{ $c.Name }} + + {{- end -}} + {{ end }} +
+ +{{ template "footer.html" . }} \ No newline at end of file diff --git a/idp/dex/web/templates/oob.html b/idp/dex/web/templates/oob.html new file mode 100755 index 000000000..b887dab61 --- /dev/null +++ b/idp/dex/web/templates/oob.html @@ -0,0 +1,19 @@ +{{ template "header.html" . }} + +
+
+ + + + +
+

Login Successful

+

+ Copy this code back to your application: +

+
+ {{ .Code }} +
+
+ +{{ template "footer.html" . }} \ No newline at end of file diff --git a/idp/dex/web/templates/password.html b/idp/dex/web/templates/password.html new file mode 100755 index 000000000..1d1b8282e --- /dev/null +++ b/idp/dex/web/templates/password.html @@ -0,0 +1,58 @@ +{{ template "header.html" . }} + +
+

Sign in

+

Enter your credentials

+ +
+ {{ if .Invalid }} +
+ Invalid {{ .UsernamePrompt }} or password. +
+ {{ end }} + +
+ + +
+ +
+ + +
+ + +
+ + {{ if .BackLink }} + + {{ end }} +
+ + + +{{ template "footer.html" . }} \ No newline at end of file diff --git a/idp/dex/web/themes/light/favicon.ico b/idp/dex/web/themes/light/favicon.ico new file mode 100644 index 000000000..2bab8a503 Binary files /dev/null and b/idp/dex/web/themes/light/favicon.ico differ diff --git a/idp/dex/web/themes/light/favicon.png b/idp/dex/web/themes/light/favicon.png new file mode 100755 index 000000000..d534ca53d Binary files /dev/null and b/idp/dex/web/themes/light/favicon.png differ diff --git a/idp/dex/web/themes/light/logo.png b/idp/dex/web/themes/light/logo.png new file mode 100755 index 000000000..d534ca53d Binary files /dev/null and b/idp/dex/web/themes/light/logo.png differ diff --git a/idp/dex/web/themes/light/styles.css b/idp/dex/web/themes/light/styles.css new file mode 100755 index 000000000..3033ebd76 --- /dev/null +++ b/idp/dex/web/themes/light/styles.css @@ -0,0 +1 @@ +/* NetBird DEX Theme - styles loaded but CSS is inline in header.html */ \ No newline at end of file diff --git a/idp/dex/web/web.go b/idp/dex/web/web.go new file mode 100644 index 000000000..8cf81392a --- /dev/null +++ b/idp/dex/web/web.go @@ -0,0 +1,14 @@ +package web + +import ( + "embed" + "io/fs" +) + +//go:embed static/* templates/* themes/* robots.txt +var files embed.FS + +// FS returns the embedded web assets filesystem. +func FS() fs.FS { + return files +} diff --git a/idp/sdk/sdk.go b/idp/sdk/sdk.go new file mode 100644 index 000000000..d2189135b --- /dev/null +++ b/idp/sdk/sdk.go @@ -0,0 +1,135 @@ +// Package sdk provides an embeddable SDK for the Dex OIDC identity provider. +package sdk + +import ( + "context" + + "github.com/dexidp/dex/storage" + + "github.com/netbirdio/netbird/idp/dex" +) + +// DexIdP wraps the Dex provider with a builder pattern +type DexIdP struct { + provider *dex.Provider + config *dex.Config + yamlConfig *dex.YAMLConfig +} + +// Option configures a DexIdP instance +type Option func(*dex.Config) + +// WithIssuer sets the OIDC issuer URL +func WithIssuer(issuer string) Option { + return func(c *dex.Config) { c.Issuer = issuer } +} + +// WithPort sets the HTTP port +func WithPort(port int) Option { + return func(c *dex.Config) { c.Port = port } +} + +// WithDataDir sets the data directory for storage +func WithDataDir(dir string) Option { + return func(c *dex.Config) { c.DataDir = dir } +} + +// WithDevMode enables development mode (allows HTTP) +func WithDevMode(dev bool) Option { + return func(c *dex.Config) { c.DevMode = dev } +} + +// WithGRPCAddr sets the gRPC API address +func WithGRPCAddr(addr string) Option { + return func(c *dex.Config) { c.GRPCAddr = addr } +} + +// New creates a new DexIdP instance with the given options +func New(opts ...Option) (*DexIdP, error) { + config := &dex.Config{ + Port: 33081, + DevMode: true, + } + + for _, opt := range opts { + opt(config) + } + + return &DexIdP{config: config}, nil +} + +// NewFromConfigFile creates a new DexIdP instance from a YAML config file +func NewFromConfigFile(path string) (*DexIdP, error) { + yamlConfig, err := dex.LoadConfig(path) + if err != nil { + return nil, err + } + return &DexIdP{yamlConfig: yamlConfig}, nil +} + +// NewFromYAMLConfig creates a new DexIdP instance from a YAMLConfig +func NewFromYAMLConfig(yamlConfig *dex.YAMLConfig) (*DexIdP, error) { + return &DexIdP{yamlConfig: yamlConfig}, nil +} + +// Start initializes and starts the embedded OIDC provider +func (d *DexIdP) Start(ctx context.Context) error { + var err error + if d.yamlConfig != nil { + d.provider, err = dex.NewProviderFromYAML(ctx, d.yamlConfig) + } else { + d.provider, err = dex.NewProvider(ctx, d.config) + } + if err != nil { + return err + } + return d.provider.Start(ctx) +} + +// Stop gracefully shuts down the provider +func (d *DexIdP) Stop(ctx context.Context) error { + if d.provider != nil { + return d.provider.Stop(ctx) + } + return nil +} + +// EnsureDefaultClients creates the default NetBird OAuth clients +func (d *DexIdP) EnsureDefaultClients(ctx context.Context, dashboardURIs, cliURIs []string) error { + return d.provider.EnsureDefaultClients(ctx, dashboardURIs, cliURIs) +} + +// Storage exposes Dex storage for direct user/client/connector management +// Use storage.Client, storage.Password, storage.Connector directly +func (d *DexIdP) Storage() storage.Storage { + return d.provider.Storage() +} + +// CreateUser creates a new user with the given email, username, and password. +// Returns the encoded user ID in Dex's format. +func (d *DexIdP) CreateUser(ctx context.Context, email, username, password string) (string, error) { + return d.provider.CreateUser(ctx, email, username, password) +} + +// DeleteUser removes a user by email +func (d *DexIdP) DeleteUser(ctx context.Context, email string) error { + return d.provider.DeleteUser(ctx, email) +} + +// ListUsers returns all users +func (d *DexIdP) ListUsers(ctx context.Context) ([]storage.Password, error) { + return d.provider.ListUsers(ctx) +} + +// IssuerURL returns the OIDC issuer URL +func (d *DexIdP) IssuerURL() string { + if d.yamlConfig != nil { + return d.yamlConfig.Issuer + } + return d.config.Issuer +} + +// DiscoveryEndpoint returns the OIDC discovery endpoint URL +func (d *DexIdP) DiscoveryEndpoint() string { + return d.IssuerURL() + "/.well-known/openid-configuration" +} diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index 2bc49d3e5..1c9c63f78 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -53,7 +53,8 @@ services: command: [ "--cert-file", "$NETBIRD_MGMT_API_CERT_FILE", "--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE", - "--log-file", "console" + "--log-file", "console", + "--port", "80" ] # Relay diff --git a/infrastructure_files/getting-started-with-dex.sh b/infrastructure_files/getting-started-with-dex.sh new file mode 100755 index 000000000..a14c6134e --- /dev/null +++ b/infrastructure_files/getting-started-with-dex.sh @@ -0,0 +1,554 @@ +#!/bin/bash + +set -e + +# NetBird Getting Started with Dex IDP +# This script sets up NetBird with Dex as the identity provider + +# Sed pattern to strip base64 padding characters +SED_STRIP_PADDING='s/=//g' + +check_docker_compose() { + if command -v docker-compose &> /dev/null + then + echo "docker-compose" + return + fi + if docker compose --help &> /dev/null + then + echo "docker compose" + return + fi + + echo "docker-compose is not installed or not in PATH. Please follow the steps from the official guide: https://docs.docker.com/engine/install/" > /dev/stderr + exit 1 +} + +check_jq() { + if ! command -v jq &> /dev/null + then + echo "jq is not installed or not in PATH, please install with your package manager. e.g. sudo apt install jq" > /dev/stderr + exit 1 + fi + return 0 +} + +get_main_ip_address() { + if [[ "$OSTYPE" == "darwin"* ]]; then + interface=$(route -n get default | grep 'interface:' | awk '{print $2}') + ip_address=$(ifconfig "$interface" | grep 'inet ' | awk '{print $2}') + else + interface=$(ip route | grep default | awk '{print $5}' | head -n 1) + ip_address=$(ip addr show "$interface" | grep 'inet ' | awk '{print $2}' | cut -d'/' -f1) + fi + + echo "$ip_address" + return 0 +} + +check_nb_domain() { + DOMAIN=$1 + if [[ "$DOMAIN-x" == "-x" ]]; then + echo "The NETBIRD_DOMAIN variable cannot be empty." > /dev/stderr + return 1 + fi + + if [[ "$DOMAIN" == "netbird.example.com" ]]; then + echo "The NETBIRD_DOMAIN cannot be netbird.example.com" > /dev/stderr + return 1 + fi + return 0 +} + +read_nb_domain() { + READ_NETBIRD_DOMAIN="" + echo -n "Enter the domain you want to use for NetBird (e.g. netbird.my-domain.com): " > /dev/stderr + read -r READ_NETBIRD_DOMAIN < /dev/tty + if ! check_nb_domain "$READ_NETBIRD_DOMAIN"; then + read_nb_domain + fi + echo "$READ_NETBIRD_DOMAIN" + return 0 +} + +get_turn_external_ip() { + TURN_EXTERNAL_IP_CONFIG="#external-ip=" + IP=$(curl -s -4 https://jsonip.com | jq -r '.ip') + if [[ "x-$IP" != "x-" ]]; then + TURN_EXTERNAL_IP_CONFIG="external-ip=$IP" + fi + echo "$TURN_EXTERNAL_IP_CONFIG" + return 0 +} + +wait_dex() { + set +e + echo -n "Waiting for Dex to become ready (via $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN)" + counter=1 + while true; do + # Check Dex through Caddy proxy (also validates TLS is working) + if curl -sk -f -o /dev/null "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex/.well-known/openid-configuration" 2>/dev/null; then + break + fi + if [[ $counter -eq 60 ]]; then + echo "" + echo "Taking too long. Checking logs..." + $DOCKER_COMPOSE_COMMAND logs --tail=20 caddy + $DOCKER_COMPOSE_COMMAND logs --tail=20 dex + fi + echo -n " ." + sleep 2 + counter=$((counter + 1)) + done + echo " done" + set -e + return 0 +} + +init_environment() { + CADDY_SECURE_DOMAIN="" + NETBIRD_PORT=80 + NETBIRD_HTTP_PROTOCOL="http" + NETBIRD_RELAY_PROTO="rel" + TURN_USER="self" + TURN_PASSWORD=$(openssl rand -base64 32 | sed "$SED_STRIP_PADDING") + NETBIRD_RELAY_AUTH_SECRET=$(openssl rand -base64 32 | sed "$SED_STRIP_PADDING") + TURN_MIN_PORT=49152 + TURN_MAX_PORT=65535 + TURN_EXTERNAL_IP_CONFIG=$(get_turn_external_ip) + + # Generate secrets for Dex + DEX_DASHBOARD_CLIENT_SECRET=$(openssl rand -base64 32 | sed "$SED_STRIP_PADDING") + + # Generate admin password + NETBIRD_ADMIN_PASSWORD=$(openssl rand -base64 16 | sed "$SED_STRIP_PADDING") + + if ! check_nb_domain "$NETBIRD_DOMAIN"; then + NETBIRD_DOMAIN=$(read_nb_domain) + fi + + if [[ "$NETBIRD_DOMAIN" == "use-ip" ]]; then + NETBIRD_DOMAIN=$(get_main_ip_address) + else + NETBIRD_PORT=443 + CADDY_SECURE_DOMAIN=", $NETBIRD_DOMAIN:$NETBIRD_PORT" + NETBIRD_HTTP_PROTOCOL="https" + NETBIRD_RELAY_PROTO="rels" + fi + + check_jq + + DOCKER_COMPOSE_COMMAND=$(check_docker_compose) + + if [[ -f dex.yaml ]]; then + echo "Generated files already exist, if you want to reinitialize the environment, please remove them first." + echo "You can use the following commands:" + echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes" + echo " rm -f docker-compose.yml Caddyfile dex.yaml dashboard.env turnserver.conf management.json relay.env" + echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard." + exit 1 + fi + + echo Rendering initial files... + render_docker_compose > docker-compose.yml + render_caddyfile > Caddyfile + render_dex_config > dex.yaml + render_dashboard_env > dashboard.env + render_management_json > management.json + render_turn_server_conf > turnserver.conf + render_relay_env > relay.env + + echo -e "\nStarting Dex IDP\n" + $DOCKER_COMPOSE_COMMAND up -d caddy dex + + # Wait for Dex to be ready (through caddy proxy) + sleep 3 + wait_dex + + echo -e "\nStarting NetBird services\n" + $DOCKER_COMPOSE_COMMAND up -d + + echo -e "\nDone!\n" + echo "You can access the NetBird dashboard at $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN" + echo "" + echo "Login with the following credentials:" + echo "Email: admin@$NETBIRD_DOMAIN" | tee .env + echo "Password: $NETBIRD_ADMIN_PASSWORD" | tee -a .env + echo "" + echo "Dex admin UI is not available (Dex has no built-in UI)." + echo "To add more users, edit dex.yaml and restart: $DOCKER_COMPOSE_COMMAND restart dex" + return 0 +} + +render_caddyfile() { + cat < /dev/null; then + ADMIN_PASSWORD_HASH=$(htpasswd -bnBC 10 "" "$NETBIRD_ADMIN_PASSWORD" | tr -d ':\n') + elif command -v python3 &> /dev/null; then + ADMIN_PASSWORD_HASH=$(python3 -c "import bcrypt; print(bcrypt.hashpw('$NETBIRD_ADMIN_PASSWORD'.encode(), bcrypt.gensalt(rounds=10)).decode())" 2>/dev/null || echo "") + fi + + # Fallback to a known hash if we can't generate one + if [[ -z "$ADMIN_PASSWORD_HASH" ]]; then + # This is hash of "password" - user should change it + ADMIN_PASSWORD_HASH='$2a$10$2b2cU8CPhOTaGrs1HRQuAueS7JTT5ZHsHSzYiFPm1leZck7Mc8T4W' + NETBIRD_ADMIN_PASSWORD="password" + echo "Warning: Could not generate password hash. Using default password: password. Please change it in dex.yaml" > /dev/stderr + fi + + cat </dev/null || cat /proc/sys/kernel/random/uuid 2>/dev/null || echo "admin-user-id-001")" + +# Optional: Add external identity provider connectors +# connectors: +# - type: github +# id: github +# name: GitHub +# config: +# clientID: \$GITHUB_CLIENT_ID +# clientSecret: \$GITHUB_CLIENT_SECRET +# redirectURI: $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/dex/callback +# +# - type: ldap +# id: ldap +# name: LDAP +# config: +# host: ldap.example.com:636 +# insecureNoSSL: false +# bindDN: cn=admin,dc=example,dc=com +# bindPW: admin +# userSearch: +# baseDN: ou=users,dc=example,dc=com +# filter: "(objectClass=person)" +# username: uid +# idAttr: uid +# emailAttr: mail +# nameAttr: cn +EOF + return 0 +} + +render_turn_server_conf() { + cat < /dev/null + then + echo "docker-compose" + return + fi + if docker compose --help &> /dev/null + then + echo "docker compose" + return + fi + + echo "docker-compose is not installed or not in PATH. Please follow the steps from the official guide: https://docs.docker.com/engine/install/" > /dev/stderr + exit 1 +} + +check_jq() { + if ! command -v jq &> /dev/null + then + echo "jq is not installed or not in PATH, please install with your package manager. e.g. sudo apt install jq" > /dev/stderr + exit 1 + fi + return 0 +} + +get_main_ip_address() { + if [[ "$OSTYPE" == "darwin"* ]]; then + interface=$(route -n get default | grep 'interface:' | awk '{print $2}') + ip_address=$(ifconfig "$interface" | grep 'inet ' | awk '{print $2}') + else + interface=$(ip route | grep default | awk '{print $5}' | head -n 1) + ip_address=$(ip addr show "$interface" | grep 'inet ' | awk '{print $2}' | cut -d'/' -f1) + fi + + echo "$ip_address" + return 0 +} + +check_nb_domain() { + DOMAIN=$1 + if [[ "$DOMAIN-x" == "-x" ]]; then + echo "The NETBIRD_DOMAIN variable cannot be empty." > /dev/stderr + return 1 + fi + + if [[ "$DOMAIN" == "netbird.example.com" ]]; then + echo "The NETBIRD_DOMAIN cannot be netbird.example.com" > /dev/stderr + return 1 + fi + return 0 +} + +read_nb_domain() { + READ_NETBIRD_DOMAIN="" + echo -n "Enter the domain you want to use for NetBird (e.g. netbird.my-domain.com): " > /dev/stderr + read -r READ_NETBIRD_DOMAIN < /dev/tty + if ! check_nb_domain "$READ_NETBIRD_DOMAIN"; then + read_nb_domain + fi + echo "$READ_NETBIRD_DOMAIN" + return 0 +} + +get_turn_external_ip() { + TURN_EXTERNAL_IP_CONFIG="#external-ip=" + IP=$(curl -s -4 https://jsonip.com | jq -r '.ip') + if [[ "x-$IP" != "x-" ]]; then + TURN_EXTERNAL_IP_CONFIG="external-ip=$IP" + fi + echo "$TURN_EXTERNAL_IP_CONFIG" + return 0 +} + +wait_management() { + set +e + echo -n "Waiting for Management server to become ready" + counter=1 + while true; do + # Check the embedded IdP endpoint + if curl -sk -f -o /dev/null "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/oauth2/.well-known/openid-configuration" 2>/dev/null; then + break + fi + if [[ $counter -eq 60 ]]; then + echo "" + echo "Taking too long. Checking logs..." + $DOCKER_COMPOSE_COMMAND logs --tail=20 caddy + $DOCKER_COMPOSE_COMMAND logs --tail=20 management + fi + echo -n " ." + sleep 2 + counter=$((counter + 1)) + done + echo " done" + set -e + return 0 +} + +init_environment() { + CADDY_SECURE_DOMAIN="" + NETBIRD_PORT=80 + NETBIRD_HTTP_PROTOCOL="http" + NETBIRD_RELAY_PROTO="rel" + TURN_USER="self" + TURN_PASSWORD=$(openssl rand -base64 32 | sed "$SED_STRIP_PADDING") + NETBIRD_RELAY_AUTH_SECRET=$(openssl rand -base64 32 | sed "$SED_STRIP_PADDING") + # Note: DataStoreEncryptionKey must keep base64 padding (=) for Go's base64.StdEncoding + DATASTORE_ENCRYPTION_KEY=$(openssl rand -base64 32) + TURN_MIN_PORT=49152 + TURN_MAX_PORT=65535 + TURN_EXTERNAL_IP_CONFIG=$(get_turn_external_ip) + + if ! check_nb_domain "$NETBIRD_DOMAIN"; then + NETBIRD_DOMAIN=$(read_nb_domain) + fi + + if [[ "$NETBIRD_DOMAIN" == "use-ip" ]]; then + NETBIRD_DOMAIN=$(get_main_ip_address) + else + NETBIRD_PORT=443 + CADDY_SECURE_DOMAIN=", $NETBIRD_DOMAIN:$NETBIRD_PORT" + NETBIRD_HTTP_PROTOCOL="https" + NETBIRD_RELAY_PROTO="rels" + fi + + check_jq + + DOCKER_COMPOSE_COMMAND=$(check_docker_compose) + + if [[ -f management.json ]]; then + echo "Generated files already exist, if you want to reinitialize the environment, please remove them first." + echo "You can use the following commands:" + echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes" + echo " rm -f docker-compose.yml Caddyfile dashboard.env turnserver.conf management.json relay.env" + echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard." + exit 1 + fi + + echo Rendering initial files... + render_docker_compose > docker-compose.yml + render_caddyfile > Caddyfile + render_dashboard_env > dashboard.env + render_management_json > management.json + render_turn_server_conf > turnserver.conf + render_relay_env > relay.env + + echo -e "\nStarting NetBird services\n" + $DOCKER_COMPOSE_COMMAND up -d + + # Wait for management (and embedded IdP) to be ready + sleep 3 + wait_management + + echo -e "\nDone!\n" + echo "You can access the NetBird dashboard at $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN" + echo "Follow the onboarding steps to set up your NetBird instance." + return 0 +} + +render_caddyfile() { + cat < 0 { + audience = audiences[0] // Use the first client ID as the primary audience + } + keysLocation = oauthProvider.GetKeysLocation() + signingKeyRefreshEnabled = true + issuer = oauthProvider.GetIssuer() + userIDClaim = oauthProvider.GetUserIDClaim() + } + return Create(s, func() auth.Manager { return auth.NewManager(s.Store(), - s.Config.HttpConfig.AuthIssuer, - s.Config.HttpConfig.AuthAudience, - s.Config.HttpConfig.AuthKeysLocation, - s.Config.HttpConfig.AuthUserIDClaim, - s.Config.GetAuthAudiences(), - s.Config.HttpConfig.IdpSignKeyRefreshEnabled) + issuer, + audience, + keysLocation, + userIDClaim, + audiences, + signingKeyRefreshEnabled) }) } diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index af9ca5f2d..d179f2b68 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -95,6 +95,17 @@ func (s *BaseServer) IdpManager() idp.Manager { return Create(s, func() idp.Manager { var idpManager idp.Manager var err error + // Use embedded IdP manager if embedded Dex is configured and enabled. + // Legacy IdpManager won't be used anymore even if configured. + if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled { + idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics()) + if err != nil { + log.Fatalf("failed to create embedded IDP manager: %v", err) + } + return idpManager + } + + // Fall back to external IdP manager if s.Config.IdpManagerConfig != nil { idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics()) if err != nil { @@ -105,6 +116,25 @@ func (s *BaseServer) IdpManager() idp.Manager { }) } +// OAuthConfigProvider is only relevant when we have an embedded IdP manager. Otherwise must be nil +func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider { + if s.Config.EmbeddedIdP == nil || !s.Config.EmbeddedIdP.Enabled { + return nil + } + + idpManager := s.IdpManager() + if idpManager == nil { + return nil + } + + // Reuse the EmbeddedIdPManager instance from IdpManager + // EmbeddedIdPManager implements both idp.Manager and idp.OAuthConfigProvider + if provider, ok := idpManager.(idp.OAuthConfigProvider); ok { + return provider + } + return nil +} + func (s *BaseServer) GroupsManager() groups.Manager { return Create(s, func() groups.Manager { return groups.NewManager(s.Store(), s.PermissionsManager(), s.AccountManager()) diff --git a/management/internals/server/server.go b/management/internals/server/server.go index a1b144dac..cd8d8e8fb 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -11,6 +11,7 @@ import ( "time" "github.com/google/uuid" + "github.com/netbirdio/netbird/management/server/idp" log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/metric" "golang.org/x/crypto/acme/autocert" @@ -22,7 +23,6 @@ import ( nbconfig "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/metrics" "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util/wsproxy" wsproxyserver "github.com/netbirdio/netbird/util/wsproxy/server" "github.com/netbirdio/netbird/version" @@ -40,7 +40,7 @@ type Server interface { SetContainer(key string, container any) } -// Server holds the HTTP BaseServer instance. +// BaseServer holds the HTTP server instance. // Add any additional fields you need, such as database connections, Config, etc. type BaseServer struct { // Config holds the server configuration @@ -129,6 +129,11 @@ func (s *BaseServer) Start(ctx context.Context) error { if s.Config.IdpManagerConfig != nil && s.Config.IdpManagerConfig.ManagerType != "" { idpManager = s.Config.IdpManagerConfig.ManagerType } + + if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled { + idpManager = metrics.EmbeddedType + } + metricsWorker := metrics.NewWorker(srvCtx, installationID, s.Store(), s.PeersUpdateManager(), idpManager) go metricsWorker.Run(srvCtx) } @@ -144,7 +149,7 @@ func (s *BaseServer) Start(ctx context.Context) error { log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String()) } - rootHandler := s.handlerFunc(s.GRPCServer(), s.APIHandler(), s.Metrics().GetMeter()) + rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.Metrics().GetMeter()) switch { case s.certManager != nil: // a call to certManager.Listener() always creates a new listener so we do it once @@ -183,7 +188,7 @@ func (s *BaseServer) Start(ctx context.Context) error { log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String()) s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled) - s.update = version.NewUpdate("nb/management") + s.update = version.NewUpdateAndStart("nb/management") s.update.SetDaemonVersion(version.NetbirdVersion()) s.update.SetOnUpdateListener(func() { log.WithContext(ctx).Infof("your management version, \"%s\", is outdated, a new management version is available. Learn more here: https://github.com/netbirdio/netbird/releases", version.NetbirdVersion()) @@ -215,6 +220,10 @@ func (s *BaseServer) Stop() error { if s.update != nil { s.update.StopWatch() } + // Stop embedded IdP if configured + if embeddedIdP, ok := s.IdpManager().(*idp.EmbeddedIdPManager); ok { + _ = embeddedIdP.Stop(ctx) + } select { case <-s.Errors(): @@ -246,11 +255,7 @@ func (s *BaseServer) SetContainer(key string, container any) { log.Tracef("container with key %s set successfully", key) } -func updateMgmtConfig(ctx context.Context, path string, config *nbconfig.Config) error { - return util.DirectWriteJson(ctx, path, config) -} - -func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler { +func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler { wsProxy := wsproxyserver.New(gRPCHandler, wsproxyserver.WithOTelMeter(meter)) return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { diff --git a/management/internals/shared/grpc/conversion.go b/management/internals/shared/grpc/conversion.go index 2b15fe4b8..455e6bd58 100644 --- a/management/internals/shared/grpc/conversion.go +++ b/management/internals/shared/grpc/conversion.go @@ -6,7 +6,11 @@ import ( "net/url" "strings" + log "github.com/sirupsen/logrus" + integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" + "github.com/netbirdio/netbird/client/ssh/auth" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" @@ -15,6 +19,7 @@ import ( "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/sshauth" ) func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig { @@ -83,15 +88,15 @@ func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken return nbConfig } -func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.PeerConfig { +func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, enableSSH bool) *proto.PeerConfig { netmask, _ := network.Net.Mask.Size() fqdn := peer.FQDN(dnsName) sshConfig := &proto.SSHConfig{ - SshEnabled: peer.SSHEnabled, + SshEnabled: peer.SSHEnabled || enableSSH, } - if peer.SSHEnabled { + if sshConfig.SshEnabled { sshConfig.JwtConfig = buildJWTConfig(httpConfig, deviceFlowConfig) } @@ -101,16 +106,20 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set Fqdn: fqdn, RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled, LazyConnectionEnabled: settings.LazyConnectionEnabled, + AutoUpdate: &proto.AutoUpdateSettings{ + Version: settings.AutoUpdateVersion, + }, } } func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { response := &proto.SyncResponse{ - PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig), + PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH), NetworkMap: &proto.NetworkMap{ - Serial: networkMap.Network.CurrentSerial(), - Routes: toProtocolRoutes(networkMap.Routes), - DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort), + Serial: networkMap.Network.CurrentSerial(), + Routes: toProtocolRoutes(networkMap.Routes), + DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort), + PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH), }, Checks: toProtocolChecks(ctx, checks), } @@ -146,9 +155,45 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb response.NetworkMap.ForwardingRules = forwardingRules } + if networkMap.AuthorizedUsers != nil { + hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers) + userIDClaim := auth.DefaultUserIDClaim + if httpConfig != nil && httpConfig.AuthUserIDClaim != "" { + userIDClaim = httpConfig.AuthUserIDClaim + } + response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim} + } + return response } +func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) { + userIDToIndex := make(map[string]uint32) + var hashedUsers [][]byte + machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers)) + + for machineUser, users := range authorizedUsers { + indexes := make([]uint32, 0, len(users)) + for userID := range users { + idx, exists := userIDToIndex[userID] + if !exists { + hash, err := sshauth.HashUserID(userID) + if err != nil { + log.WithContext(ctx).Errorf("failed to hash user id %s: %v", userID, err) + continue + } + idx = uint32(len(hashedUsers)) + userIDToIndex[userID] = idx + hashedUsers = append(hashedUsers, hash[:]) + } + indexes = append(indexes, idx) + } + machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes} + } + + return hashedUsers, machineUsers +} + func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig { for _, rPeer := range peers { dst = append(dst, &proto.RemotePeerConfig{ @@ -383,9 +428,13 @@ func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfi keysLocation = strings.TrimSuffix(issuer, "/") + "/.well-known/jwks.json" } + audience := config.AuthAudience + if config.CLIAuthAudience != "" { + audience = config.CLIAuthAudience + } return &proto.JWTConfig{ Issuer: issuer, - Audience: config.AuthAudience, + Audience: audience, KeysLocation: keysLocation, } } diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 462e2e6eb..801c15158 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -16,6 +16,7 @@ import ( pb "github.com/golang/protobuf/proto" // nolint "github.com/golang/protobuf/ptypes/timestamp" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" + "github.com/netbirdio/netbird/shared/management/client/common" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc/codes" @@ -24,6 +25,7 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/store" @@ -69,6 +71,8 @@ type Server struct { networkMapController network_map.Controller + oAuthConfigProvider idp.OAuthConfigProvider + syncSem atomic.Int32 syncLim int32 } @@ -83,6 +87,7 @@ func NewServer( authManager auth.Manager, integratedPeerValidator integrated_validator.IntegratedValidator, networkMapController network_map.Controller, + oAuthConfigProvider idp.OAuthConfigProvider, ) (*Server, error) { if appMetrics != nil { // update gauge based on number of connected peers which is equal to open gRPC streams @@ -119,6 +124,7 @@ func NewServer( blockPeersWithSameConfig: blockPeersWithSameConfig, integratedPeerValidator: integratedPeerValidator, networkMapController: networkMapController, + oAuthConfigProvider: oAuthConfigProvider, loginFilter: newLoginFilter(), @@ -184,8 +190,14 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S realIP := getRealIP(ctx) sRealIP := realIP.String() peerMeta := extractPeerMeta(ctx, syncReq.GetMeta()) + userID, err := s.accountManager.GetUserIDByPeerKey(ctx, peerKey.String()) + if err != nil { + s.syncSem.Add(-1) + return mapError(ctx, err) + } + metahashed := metaHash(peerMeta, sRealIP) - if !s.loginFilter.allowLogin(peerKey.String(), metahashed) { + if userID == "" && !s.loginFilter.allowLogin(peerKey.String(), metahashed) { if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountSyncRequestBlocked() } @@ -270,6 +282,8 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S unlock() unlock = nil + log.WithContext(ctx).Debugf("Sync took %s", time.Since(reqStart)) + s.syncSem.Add(-1) return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv) @@ -559,6 +573,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID) } + log.WithContext(ctx).Debugf("Login took %s", time.Since(reqStart)) }() if loginReq.GetMeta() == nil { @@ -635,7 +650,7 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne // if peer has reached this point then it has logged in loginResp := &proto.LoginResponse{ NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil), - PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow), + PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, netMap.EnableSSH), Checks: toProtocolChecks(ctx, postureChecks), } @@ -752,32 +767,48 @@ func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.Encr return nil, status.Error(codes.InvalidArgument, errMSG) } - if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(nbconfig.NONE) { - return nil, status.Error(codes.NotFound, "no device authorization flow information available") - } + var flowInfoResp *proto.DeviceAuthorizationFlow - provider, ok := proto.DeviceAuthorizationFlowProvider_value[strings.ToUpper(s.config.DeviceAuthorizationFlow.Provider)] - if !ok { - return nil, status.Errorf(codes.InvalidArgument, "no provider found in the protocol for %s", s.config.DeviceAuthorizationFlow.Provider) - } + // Use embedded IdP configuration if available + if s.oAuthConfigProvider != nil { + flowInfoResp = &proto.DeviceAuthorizationFlow{ + Provider: proto.DeviceAuthorizationFlow_HOSTED, + ProviderConfig: &proto.ProviderConfig{ + ClientID: s.oAuthConfigProvider.GetCLIClientID(), + Audience: s.oAuthConfigProvider.GetCLIClientID(), + DeviceAuthEndpoint: s.oAuthConfigProvider.GetDeviceAuthEndpoint(), + TokenEndpoint: s.oAuthConfigProvider.GetTokenEndpoint(), + Scope: s.oAuthConfigProvider.GetDefaultScopes(), + }, + } + } else { + if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(nbconfig.NONE) { + return nil, status.Error(codes.NotFound, "no device authorization flow information available") + } - flowInfoResp := &proto.DeviceAuthorizationFlow{ - Provider: proto.DeviceAuthorizationFlowProvider(provider), - ProviderConfig: &proto.ProviderConfig{ - ClientID: s.config.DeviceAuthorizationFlow.ProviderConfig.ClientID, - ClientSecret: s.config.DeviceAuthorizationFlow.ProviderConfig.ClientSecret, - Domain: s.config.DeviceAuthorizationFlow.ProviderConfig.Domain, - Audience: s.config.DeviceAuthorizationFlow.ProviderConfig.Audience, - DeviceAuthEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint, - TokenEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint, - Scope: s.config.DeviceAuthorizationFlow.ProviderConfig.Scope, - UseIDToken: s.config.DeviceAuthorizationFlow.ProviderConfig.UseIDToken, - }, + provider, ok := proto.DeviceAuthorizationFlowProvider_value[strings.ToUpper(s.config.DeviceAuthorizationFlow.Provider)] + if !ok { + return nil, status.Errorf(codes.InvalidArgument, "no provider found in the protocol for %s", s.config.DeviceAuthorizationFlow.Provider) + } + + flowInfoResp = &proto.DeviceAuthorizationFlow{ + Provider: proto.DeviceAuthorizationFlowProvider(provider), + ProviderConfig: &proto.ProviderConfig{ + ClientID: s.config.DeviceAuthorizationFlow.ProviderConfig.ClientID, + ClientSecret: s.config.DeviceAuthorizationFlow.ProviderConfig.ClientSecret, + Domain: s.config.DeviceAuthorizationFlow.ProviderConfig.Domain, + Audience: s.config.DeviceAuthorizationFlow.ProviderConfig.Audience, + DeviceAuthEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint, + TokenEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint, + Scope: s.config.DeviceAuthorizationFlow.ProviderConfig.Scope, + UseIDToken: s.config.DeviceAuthorizationFlow.ProviderConfig.UseIDToken, + }, + } } encryptedResp, err := encryption.EncryptMessage(peerKey, key, flowInfoResp) if err != nil { - return nil, status.Error(codes.Internal, "failed to encrypt no device authorization flow information") + return nil, status.Error(codes.Internal, "failed to encrypt device authorization flow information") } return &proto.EncryptedMessage{ @@ -811,30 +842,47 @@ func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.Encryp return nil, status.Error(codes.InvalidArgument, errMSG) } - if s.config.PKCEAuthorizationFlow == nil { - return nil, status.Error(codes.NotFound, "no pkce authorization flow information available") - } + var initInfoFlow *proto.PKCEAuthorizationFlow - initInfoFlow := &proto.PKCEAuthorizationFlow{ - ProviderConfig: &proto.ProviderConfig{ - Audience: s.config.PKCEAuthorizationFlow.ProviderConfig.Audience, - ClientID: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientID, - ClientSecret: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientSecret, - TokenEndpoint: s.config.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint, - AuthorizationEndpoint: s.config.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint, - Scope: s.config.PKCEAuthorizationFlow.ProviderConfig.Scope, - RedirectURLs: s.config.PKCEAuthorizationFlow.ProviderConfig.RedirectURLs, - UseIDToken: s.config.PKCEAuthorizationFlow.ProviderConfig.UseIDToken, - DisablePromptLogin: s.config.PKCEAuthorizationFlow.ProviderConfig.DisablePromptLogin, - LoginFlag: uint32(s.config.PKCEAuthorizationFlow.ProviderConfig.LoginFlag), - }, + // Use embedded IdP configuration if available + if s.oAuthConfigProvider != nil { + initInfoFlow = &proto.PKCEAuthorizationFlow{ + ProviderConfig: &proto.ProviderConfig{ + Audience: s.oAuthConfigProvider.GetCLIClientID(), + ClientID: s.oAuthConfigProvider.GetCLIClientID(), + TokenEndpoint: s.oAuthConfigProvider.GetTokenEndpoint(), + AuthorizationEndpoint: s.oAuthConfigProvider.GetAuthorizationEndpoint(), + Scope: s.oAuthConfigProvider.GetDefaultScopes(), + RedirectURLs: s.oAuthConfigProvider.GetCLIRedirectURLs(), + LoginFlag: uint32(common.LoginFlagPromptLogin), + }, + } + } else { + if s.config.PKCEAuthorizationFlow == nil { + return nil, status.Error(codes.NotFound, "no pkce authorization flow information available") + } + + initInfoFlow = &proto.PKCEAuthorizationFlow{ + ProviderConfig: &proto.ProviderConfig{ + Audience: s.config.PKCEAuthorizationFlow.ProviderConfig.Audience, + ClientID: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientID, + ClientSecret: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientSecret, + TokenEndpoint: s.config.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint, + AuthorizationEndpoint: s.config.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint, + Scope: s.config.PKCEAuthorizationFlow.ProviderConfig.Scope, + RedirectURLs: s.config.PKCEAuthorizationFlow.ProviderConfig.RedirectURLs, + UseIDToken: s.config.PKCEAuthorizationFlow.ProviderConfig.UseIDToken, + DisablePromptLogin: s.config.PKCEAuthorizationFlow.ProviderConfig.DisablePromptLogin, + LoginFlag: uint32(s.config.PKCEAuthorizationFlow.ProviderConfig.LoginFlag), + }, + } } flowInfoResp := s.integratedPeerValidator.ValidateFlowResponse(ctx, peerKey.String(), initInfoFlow) encryptedResp, err := encryption.EncryptMessage(peerKey, key, flowInfoResp) if err != nil { - return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information") + return nil, status.Error(codes.Internal, "failed to encrypt pkce authorization flow information") } return &proto.EncryptedMessage{ diff --git a/management/server/account.go b/management/server/account.go index a9becc4b6..29415b038 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -243,7 +243,7 @@ func BuildManager( am.externalCacheManager = nbcache.NewUserDataCache(cacheStore) am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, cacheStore) - if !isNil(am.idpManager) { + if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) { go func() { err := am.warmupIDPCache(ctx, cacheStore) if err != nil { @@ -321,7 +321,8 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled || oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled || - oldSettings.DNSDomain != newSettings.DNSDomain { + oldSettings.DNSDomain != newSettings.DNSDomain || + oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion { updateAccountPeers = true } @@ -332,8 +333,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } } - newSettings.Extra.IntegratedValidatorGroups = oldSettings.Extra.IntegratedValidatorGroups - newSettings.Extra.IntegratedValidator = oldSettings.Extra.IntegratedValidator + if newSettings.Extra == nil { + newSettings.Extra = oldSettings.Extra + } if err = transaction.SaveAccountSettings(ctx, accountID, newSettings); err != nil { return err @@ -360,6 +362,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco am.handleLazyConnectionSettings(ctx, oldSettings, newSettings, userID, accountID) am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID) + am.handleAutoUpdateVersionSettings(ctx, oldSettings, newSettings, userID, accountID) if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil { return nil, err } @@ -451,6 +454,14 @@ func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Con } } +func (am *DefaultAccountManager) handleAutoUpdateVersionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) { + if oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountAutoUpdateVersionUpdated, map[string]any{ + "version": newSettings.AutoUpdateVersion, + }) + } +} + func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error { if newSettings.PeerInactivityExpirationEnabled { if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { @@ -546,7 +557,7 @@ func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx co // newAccount creates a new Account with a generated ID and generated default setup keys. // If ID is already in use (due to collision) we try one more time before returning error -func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*types.Account, error) { +func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain, email, name string) (*types.Account, error) { for i := 0; i < 2; i++ { accountId := xid.New().String() @@ -557,7 +568,7 @@ func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain log.WithContext(ctx).Warnf("an account with ID already exists, retrying...") continue case statusErr.Type() == status.NotFound: - newAccount := newAccountWithId(ctx, accountId, userID, domain, am.disableDefaultPolicy) + newAccount := newAccountWithId(ctx, accountId, userID, domain, email, name, am.disableDefaultPolicy) am.StoreEvent(ctx, userID, newAccount.Id, accountId, activity.AccountCreated, nil) return newAccount, nil default: @@ -730,23 +741,23 @@ func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID st // If user does have an account, it returns the user's account ID. // If the user doesn't have an account, it creates one using the provided domain. // Returns the account ID or an error if none is found or created. -func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) { - if userID == "" { +func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth auth.UserAuth) (string, error) { + if userAuth.UserId == "" { return "", status.Errorf(status.NotFound, "no valid userID provided") } - accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userID) + accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId) if err != nil { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { - account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) + acc, err := am.GetOrCreateAccountByUser(ctx, userAuth) if err != nil { - return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) + return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userAuth.UserId) } - if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil { + if err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, acc.Id); err != nil { return "", err } - return account.Id, nil + return acc.Id, nil } return "", err } @@ -757,9 +768,19 @@ func isNil(i idp.Manager) bool { return i == nil || reflect.ValueOf(i).IsNil() } +// IsEmbeddedIdp checks if the IDP manager is an embedded IDP (data stored locally in DB). +// When true, user cache should be skipped and data fetched directly from the IDP manager. +func IsEmbeddedIdp(i idp.Manager) bool { + if isNil(i) { + return false + } + _, ok := i.(*idp.EmbeddedIdPManager) + return ok +} + // addAccountIDToIDPAppMeta update user's app metadata in idp manager func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error { - if !isNil(am.idpManager) { + if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) { // user can be nil if it wasn't found (e.g., just created) user, err := am.lookupUserInCache(ctx, userID, accountID) if err != nil { @@ -1005,6 +1026,9 @@ func (am *DefaultAccountManager) isCacheFresh(ctx context.Context, accountUsers } func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accountID, userID string) error { + if IsEmbeddedIdp(am.idpManager) { + return nil + } data, err := am.getAccountFromCache(ctx, accountID, false) if err != nil { return err @@ -1096,7 +1120,7 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai lowerDomain := strings.ToLower(userAuth.Domain) - newAccount, err := am.newAccount(ctx, userAuth.UserId, lowerDomain) + newAccount, err := am.newAccount(ctx, userAuth.UserId, lowerDomain, userAuth.Email, userAuth.Name) if err != nil { return "", err } @@ -1121,7 +1145,7 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai } func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth auth.UserAuth) (string, error) { - newUser := types.NewRegularUser(userAuth.UserId) + newUser := types.NewRegularUser(userAuth.UserId, userAuth.Email, userAuth.Name) newUser.AccountID = domainAccountID settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, domainAccountID) @@ -1304,6 +1328,7 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId) if err != nil { // this is not really possible because we got an account by user ID + log.Errorf("failed to get user by ID %s: %v", userAuth.UserId, err) return "", "", status.Errorf(status.NotFound, "user %s not found", userAuth.UserId) } @@ -1446,21 +1471,19 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth } } - if settings.GroupsPropagationEnabled { - removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups) - if err != nil { - return err - } + removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups) + if err != nil { + return err + } - newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups) - if err != nil { - return err - } + newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups) + if err != nil { + return err + } - if removedGroupAffectsPeers || newGroupsAffectsPeers { - log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId) - am.BufferUpdateAccountPeers(ctx, userAuth.AccountId) - } + if removedGroupAffectsPeers || newGroupsAffectsPeers { + log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId) + am.BufferUpdateAccountPeers(ctx, userAuth.AccountId) } return nil @@ -1503,7 +1526,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context } if userAuth.DomainCategory != types.PrivateCategory || !isDomainValid(userAuth.Domain) { - return am.GetAccountIDByUserID(ctx, userAuth.UserId, userAuth.Domain) + return am.GetAccountIDByUserID(ctx, userAuth) } if userAuth.AccountId != "" { @@ -1725,7 +1748,7 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account } // newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id -func newAccountWithId(ctx context.Context, accountID, userID, domain string, disableDefaultPolicy bool) *types.Account { +func newAccountWithId(ctx context.Context, accountID, userID, domain, email, name string, disableDefaultPolicy bool) *types.Account { log.WithContext(ctx).Debugf("creating new account") network := types.NewNetwork() @@ -1735,7 +1758,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string, dis setupKeys := map[string]*types.SetupKey{} nameServersGroups := make(map[string]*nbdns.NameServerGroup) - owner := types.NewOwnerUser(userID) + owner := types.NewOwnerUser(userID, email, name) owner.AccountID = accountID users[userID] = owner @@ -2148,3 +2171,7 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti return nil } + +func (am *DefaultAccountManager) GetUserIDByPeerKey(ctx context.Context, peerKey string) (string, error) { + return am.Store.GetUserIDByPeerKey(ctx, store.LockingStrengthNone, peerKey) +} diff --git a/management/server/account/manager.go b/management/server/account/manager.go index b5921ec7a..7680a8464 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -24,7 +24,7 @@ import ( type ExternalCacheManager nbcache.UserDataCache type Manager interface { - GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*types.Account, error) + GetOrCreateAccountByUser(ctx context.Context, userAuth auth.UserAuth) (*types.Account, error) GetAccount(ctx context.Context, accountID string) (*types.Account, error) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType types.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error) @@ -44,7 +44,7 @@ type Manager interface { GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) AccountExists(ctx context.Context, accountID string) (bool, error) - GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) + GetAccountIDByUserID(ctx context.Context, userAuth auth.UserAuth) (string, error) GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) DeleteAccount(ctx context.Context, accountID, userID string) error GetUserByID(ctx context.Context, id string) (*types.User, error) @@ -123,4 +123,10 @@ type Manager interface { UpdateToPrimaryAccount(ctx context.Context, accountId string) error GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) + GetUserIDByPeerKey(ctx context.Context, peerKey string) (string, error) + GetIdentityProvider(ctx context.Context, accountID, idpID, userID string) (*types.IdentityProvider, error) + GetIdentityProviders(ctx context.Context, accountID, userID string) ([]*types.IdentityProvider, error) + CreateIdentityProvider(ctx context.Context, accountID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error) + UpdateIdentityProvider(ctx context.Context, accountID, idpID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error) + DeleteIdentityProvider(ctx context.Context, accountID, idpID, userID string) error } diff --git a/management/server/account_test.go b/management/server/account_test.go index 7f125e3a0..59d6e4928 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -382,7 +382,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { } for _, testCase := range tt { - account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io", false) + account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io", "", "", false) account.UpdateSettings(&testCase.accountSettings) account.Network = network account.Peers = testCase.peers @@ -397,7 +397,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { } customZone := account.GetPeersCustomZone(context.Background(), "netbird.io") - networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers()) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) } @@ -407,7 +407,7 @@ func TestNewAccount(t *testing.T) { domain := "netbird.io" userId := "account_creator" accountID := "account_id" - account := newAccountWithId(context.Background(), accountID, userId, domain, false) + account := newAccountWithId(context.Background(), accountID, userId, domain, "", "", false) verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId}) } @@ -418,7 +418,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { return } - account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID, Domain: ""}) if err != nil { t.Fatal(err) } @@ -612,7 +612,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain) + accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain}) require.NoError(t, err, "create init user failed") initAccount, err := manager.Store.GetAccount(context.Background(), accountID) @@ -649,10 +649,10 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) { userId := "user-id" domain := "test.domain" - _ = newAccountWithId(context.Background(), "", userId, domain, false) + _ = newAccountWithId(context.Background(), "", userId, domain, "", "", false) manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain) + accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userId, Domain: domain}) require.NoError(t, err, "create init user failed") // as initAccount was created without account id we have to take the id after account initialization // that happens inside the GetAccountIDByUserID where the id is getting generated @@ -718,7 +718,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) { } userId := "test_user" - account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, "") + account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userId, Domain: ""}) if err != nil { t.Fatal(err) } @@ -745,7 +745,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { userId := "test_user" domain := "hotmail.com" - account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, domain) + account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userId, Domain: domain}) if err != nil { t.Fatal(err) } @@ -759,7 +759,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { domain = "gmail.com" - account, err = manager.GetOrCreateAccountByUser(context.Background(), userId, domain) + account, err = manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userId, Domain: domain}) if err != nil { t.Fatalf("got the following error while retrieving existing acc: %v", err) } @@ -782,7 +782,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) { userId := "test_user" - accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userId, Domain: ""}) if err != nil { t.Fatal(err) } @@ -795,14 +795,14 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) { assert.NoError(t, err) assert.True(t, exists, "expected to get existing account after creation using userid") - _, err = manager.GetAccountIDByUserID(context.Background(), "", "") + _, err = manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: "", Domain: ""}) if err == nil { t.Errorf("expected an error when user ID is empty") } } func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*types.Account, error) { - account := newAccountWithId(context.Background(), accountID, userID, domain, false) + account := newAccountWithId(context.Background(), accountID, userID, domain, "", "", false) err := am.Store.SaveAccount(context.Background(), account) if err != nil { return nil, err @@ -1098,7 +1098,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { return } - account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "netbird.cloud") + account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID, Domain: "netbird.cloud"}) if err != nil { t.Fatal(err) } @@ -1849,7 +1849,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID}) require.NoError(t, err, "unable to create an account") settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID) @@ -1864,7 +1864,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") - _, err = manager.GetAccountIDByUserID(context.Background(), userID, "") + _, err = manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID}) require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1876,7 +1876,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { }, false) require.NoError(t, err, "unable to add peer") - accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID}) require.NoError(t, err, "unable to get the account") err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) @@ -1920,7 +1920,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID}) require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1946,7 +1946,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. }, } - accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "") + accountID, err = manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID}) require.NoError(t, err, "unable to get the account") // when we mark peer as connected, the peer login expiration routine should trigger @@ -1963,7 +1963,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") - _, err = manager.GetAccountIDByUserID(context.Background(), userID, "") + _, err = manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID}) require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1975,7 +1975,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test }, false) require.NoError(t, err, "unable to add peer") - accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID}) require.NoError(t, err, "unable to get the account") account, err := manager.Store.GetAccount(context.Background(), accountID) @@ -2025,7 +2025,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID}) require.NoError(t, err, "unable to create an account") updatedSettings, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ @@ -3434,7 +3434,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) { assert.True(t, cold) }) - account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID}) require.NoError(t, err) t.Run("should return true when account is not found in cache", func(t *testing.T) { @@ -3462,7 +3462,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) { initiatorId := "test-user" domain := "example.com" - account, err := manager.GetOrCreateAccountByUser(ctx, initiatorId, domain) + account, err := manager.GetOrCreateAccountByUser(ctx, auth.UserAuth{UserId: initiatorId, Domain: domain}) require.NoError(t, err) peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"} @@ -3575,7 +3575,7 @@ func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) { manager, _, err := createManager(t) require.NoError(t, err) - account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID}) require.NoError(t, err) t.Run("should return account onboarding when onboarding exist", func(t *testing.T) { @@ -3607,7 +3607,7 @@ func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) { manager, _, err := createManager(t) require.NoError(t, err) - account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID}) require.NoError(t, err) onboarding := &types.AccountOnboarding{ @@ -3646,7 +3646,7 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) { manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID}) require.NoError(t, err, "unable to create an account") key1, err := wgtypes.GenerateKey() @@ -3717,7 +3717,7 @@ func TestAddNewUserToDomainAccountWithApproval(t *testing.T) { // Create a domain-based account with user approval enabled existingAccountID := "existing-account" - account := newAccountWithId(context.Background(), existingAccountID, "owner-user", "example.com", false) + account := newAccountWithId(context.Background(), existingAccountID, "owner-user", "example.com", "", "", false) account.Settings.Extra = &types.ExtraSettings{ UserApprovalRequired: true, } diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 2e3be1ef5..7b939ddff 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -181,6 +181,12 @@ const ( UserRejected Activity = 90 UserCreated Activity = 91 + AccountAutoUpdateVersionUpdated Activity = 92 + + IdentityProviderCreated Activity = 93 + IdentityProviderUpdated Activity = 94 + IdentityProviderDeleted Activity = 95 + AccountDeleted Activity = 99999 ) @@ -287,9 +293,16 @@ var activityMap = map[Activity]Code{ AccountNetworkRangeUpdated: {"Account network range updated", "account.network.range.update"}, PeerIPUpdated: {"Peer IP updated", "peer.ip.update"}, - UserApproved: {"User approved", "user.approve"}, - UserRejected: {"User rejected", "user.reject"}, - UserCreated: {"User created", "user.create"}, + + UserApproved: {"User approved", "user.approve"}, + UserRejected: {"User rejected", "user.reject"}, + UserCreated: {"User created", "user.create"}, + + AccountAutoUpdateVersionUpdated: {"Account AutoUpdate Version updated", "account.settings.auto.version.update"}, + + IdentityProviderCreated: {"Identity provider created", "identityprovider.create"}, + IdentityProviderUpdated: {"Identity provider updated", "identityprovider.update"}, + IdentityProviderDeleted: {"Identity provider deleted", "identityprovider.delete"}, } // StringCode returns a string code of the activity diff --git a/management/server/auth/manager.go b/management/server/auth/manager.go index 0c62357dc..76cc750b6 100644 --- a/management/server/auth/manager.go +++ b/management/server/auth/manager.go @@ -49,8 +49,7 @@ func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim s ) return &manager{ - store: store, - + store: store, validator: jwtValidator, extractor: claimsExtractor, } diff --git a/management/server/dns_test.go b/management/server/dns_test.go index b5e3f2b99..d1da79380 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -277,7 +277,7 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account domain := "example.com" - account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain, false) + account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain, "", "", false) account.Users[dnsRegularUserID] = &types.User{ Id: dnsRegularUserID, diff --git a/management/server/group.go b/management/server/group.go index 84e641f26..9fc8db120 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -427,7 +427,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { for _, groupID := range groupIDs { - group, err := transaction.GetGroupByID(ctx, store.LockingStrengthUpdate, accountID, groupID) + group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID) if err != nil { allErrors = errors.Join(allErrors, err) continue @@ -442,6 +442,10 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us deletedGroups = append(deletedGroups, group) } + if len(groupIDsToDelete) == 0 { + return allErrors + } + if err = transaction.DeleteGroups(ctx, accountID, groupIDsToDelete); err != nil { return err } diff --git a/management/server/group_test.go b/management/server/group_test.go index 4935dac5d..95f37a3ff 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -379,7 +379,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t Id: "example user", AutoGroups: []string{groupForUsers.ID}, } - account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, false) + account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, "", "", false) account.Routes[routeResource.ID] = routeResource account.Routes[routePeerGroupResource.ID] = routePeerGroupResource account.NameServerGroups[nameServerGroup.ID] = nameServerGroup diff --git a/management/server/http/handler.go b/management/server/http/handler.go index b7c6c113c..bbd6b4750 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -9,6 +9,7 @@ import ( "time" "github.com/gorilla/mux" + idpmanager "github.com/netbirdio/netbird/management/server/idp" "github.com/rs/cors" log "github.com/sirupsen/logrus" @@ -29,6 +30,8 @@ import ( "github.com/netbirdio/netbird/management/server/http/handlers/dns" "github.com/netbirdio/netbird/management/server/http/handlers/events" "github.com/netbirdio/netbird/management/server/http/handlers/groups" + "github.com/netbirdio/netbird/management/server/http/handlers/idp" + "github.com/netbirdio/netbird/management/server/http/handlers/instance" "github.com/netbirdio/netbird/management/server/http/handlers/networks" "github.com/netbirdio/netbird/management/server/http/handlers/peers" "github.com/netbirdio/netbird/management/server/http/handlers/policies" @@ -36,6 +39,8 @@ import ( "github.com/netbirdio/netbird/management/server/http/handlers/setup_keys" "github.com/netbirdio/netbird/management/server/http/handlers/users" "github.com/netbirdio/netbird/management/server/http/middleware" + "github.com/netbirdio/netbird/management/server/http/middleware/bypass" + nbinstance "github.com/netbirdio/netbird/management/server/instance" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" nbnetworks "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" @@ -51,23 +56,15 @@ const ( ) // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. -func NewAPIHandler( - ctx context.Context, - accountManager account.Manager, - networksManager nbnetworks.Manager, - resourceManager resources.Manager, - routerManager routers.Manager, - groupsManager nbgroups.Manager, - LocationManager geolocation.Geolocation, - authManager auth.Manager, - appMetrics telemetry.AppMetrics, - integratedValidator integrated_validator.IntegratedValidator, - proxyController port_forwarding.Controller, - permissionsManager permissions.Manager, - peersManager nbpeers.Manager, - settingsManager settings.Manager, - networkMapController network_map.Controller, -) (http.Handler, error) { +func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager) (http.Handler, error) { + + // Register bypass paths for unauthenticated endpoints + if err := bypass.AddBypassPath("/api/instance"); err != nil { + return nil, fmt.Errorf("failed to add bypass path: %w", err) + } + if err := bypass.AddBypassPath("/api/setup"); err != nil { + return nil, fmt.Errorf("failed to add bypass path: %w", err) + } var rateLimitingConfig *middleware.RateLimiterConfig if os.Getenv(rateLimitingEnabledKey) == "true" { @@ -122,7 +119,14 @@ func NewAPIHandler( return nil, fmt.Errorf("register integrations endpoints: %w", err) } - accounts.AddEndpoints(accountManager, settingsManager, router) + // Check if embedded IdP is enabled + embeddedIdP, embeddedIdpEnabled := idpManager.(*idpmanager.EmbeddedIdPManager) + instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), embeddedIdP) + if err != nil { + return nil, fmt.Errorf("failed to create instance manager: %w", err) + } + + accounts.AddEndpoints(accountManager, settingsManager, embeddedIdpEnabled, router) peers.AddEndpoints(accountManager, router, networkMapController) users.AddEndpoints(accountManager, router) setup_keys.AddEndpoints(accountManager, router) @@ -134,6 +138,13 @@ func NewAPIHandler( dns.AddEndpoints(accountManager, router) events.AddEndpoints(accountManager, router) networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router) + idp.AddEndpoints(accountManager, router) + instance.AddEndpoints(instanceManager, router) + + // Mount embedded IdP handler at /oauth2 path if configured + if embeddedIdpEnabled { + rootRouter.PathPrefix("/oauth2").Handler(corsMiddleware.Handler(embeddedIdP.Handler())) + } return rootRouter, nil } diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index f1552d0ea..de778d59a 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -3,12 +3,15 @@ package accounts import ( "context" "encoding/json" + "fmt" "net/http" "net/netip" "time" "github.com/gorilla/mux" + goversion "github.com/hashicorp/go-version" + "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/settings" @@ -26,27 +29,31 @@ const ( // MinNetworkBits is the minimum prefix length for IPv4 network ranges (e.g., /29 gives 8 addresses, /28 gives 16) MinNetworkBitsIPv4 = 28 // MinNetworkBitsIPv6 is the minimum prefix length for IPv6 network ranges - MinNetworkBitsIPv6 = 120 + MinNetworkBitsIPv6 = 120 + disableAutoUpdate = "disabled" + autoUpdateLatestVersion = "latest" ) // handler is a handler that handles the server.Account HTTP endpoints type handler struct { - accountManager account.Manager - settingsManager settings.Manager + accountManager account.Manager + settingsManager settings.Manager + embeddedIdpEnabled bool } -func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router) { - accountsHandler := newHandler(accountManager, settingsManager) +func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, embeddedIdpEnabled bool, router *mux.Router) { + accountsHandler := newHandler(accountManager, settingsManager, embeddedIdpEnabled) router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS") router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS") router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS") } // newHandler creates a new handler HTTP handler -func newHandler(accountManager account.Manager, settingsManager settings.Manager) *handler { +func newHandler(accountManager account.Manager, settingsManager settings.Manager, embeddedIdpEnabled bool) *handler { return &handler{ - accountManager: accountManager, - settingsManager: settingsManager, + accountManager: accountManager, + settingsManager: settingsManager, + embeddedIdpEnabled: embeddedIdpEnabled, } } @@ -158,10 +165,65 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) { return } - resp := toAccountResponse(accountID, settings, meta, onboarding) + resp := toAccountResponse(accountID, settings, meta, onboarding, h.embeddedIdpEnabled) util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) } +func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJSONRequestBody) (*types.Settings, error) { + returnSettings := &types.Settings{ + PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled, + PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)), + RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked, + + PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled, + PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)), + } + + if req.Settings.Extra != nil { + returnSettings.Extra = &types.ExtraSettings{ + PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled, + UserApprovalRequired: req.Settings.Extra.UserApprovalRequired, + FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled, + FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups, + FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled, + } + } + + if req.Settings.JwtGroupsEnabled != nil { + returnSettings.JWTGroupsEnabled = *req.Settings.JwtGroupsEnabled + } + if req.Settings.GroupsPropagationEnabled != nil { + returnSettings.GroupsPropagationEnabled = *req.Settings.GroupsPropagationEnabled + } + if req.Settings.JwtGroupsClaimName != nil { + returnSettings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName + } + if req.Settings.JwtAllowGroups != nil { + returnSettings.JWTAllowGroups = *req.Settings.JwtAllowGroups + } + if req.Settings.RoutingPeerDnsResolutionEnabled != nil { + returnSettings.RoutingPeerDNSResolutionEnabled = *req.Settings.RoutingPeerDnsResolutionEnabled + } + if req.Settings.DnsDomain != nil { + returnSettings.DNSDomain = *req.Settings.DnsDomain + } + if req.Settings.LazyConnectionEnabled != nil { + returnSettings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled + } + if req.Settings.AutoUpdateVersion != nil { + _, err := goversion.NewSemver(*req.Settings.AutoUpdateVersion) + if *req.Settings.AutoUpdateVersion == autoUpdateLatestVersion || + *req.Settings.AutoUpdateVersion == disableAutoUpdate || + err == nil { + returnSettings.AutoUpdateVersion = *req.Settings.AutoUpdateVersion + } else if *req.Settings.AutoUpdateVersion != "" { + return nil, fmt.Errorf("invalid AutoUpdateVersion") + } + } + + return returnSettings, nil +} + // updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) @@ -186,45 +248,10 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { return } - settings := &types.Settings{ - PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled, - PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)), - RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked, - - PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled, - PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)), - } - - if req.Settings.Extra != nil { - settings.Extra = &types.ExtraSettings{ - PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled, - UserApprovalRequired: req.Settings.Extra.UserApprovalRequired, - FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled, - FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups, - FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled, - } - } - - if req.Settings.JwtGroupsEnabled != nil { - settings.JWTGroupsEnabled = *req.Settings.JwtGroupsEnabled - } - if req.Settings.GroupsPropagationEnabled != nil { - settings.GroupsPropagationEnabled = *req.Settings.GroupsPropagationEnabled - } - if req.Settings.JwtGroupsClaimName != nil { - settings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName - } - if req.Settings.JwtAllowGroups != nil { - settings.JWTAllowGroups = *req.Settings.JwtAllowGroups - } - if req.Settings.RoutingPeerDnsResolutionEnabled != nil { - settings.RoutingPeerDNSResolutionEnabled = *req.Settings.RoutingPeerDnsResolutionEnabled - } - if req.Settings.DnsDomain != nil { - settings.DNSDomain = *req.Settings.DnsDomain - } - if req.Settings.LazyConnectionEnabled != nil { - settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled + settings, err := h.updateAccountRequestSettings(req) + if err != nil { + util.WriteError(r.Context(), err, w) + return } if req.Settings.NetworkRange != nil && *req.Settings.NetworkRange != "" { prefix, err := netip.ParsePrefix(*req.Settings.NetworkRange) @@ -265,7 +292,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { return } - resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding) + resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding, h.embeddedIdpEnabled) util.WriteJSONObject(r.Context(), w, &resp) } @@ -294,7 +321,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding) *api.Account { +func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding, embeddedIdpEnabled bool) *api.Account { jwtAllowGroups := settings.JWTAllowGroups if jwtAllowGroups == nil { jwtAllowGroups = []string{} @@ -313,6 +340,8 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled, LazyConnectionEnabled: &settings.LazyConnectionEnabled, DnsDomain: &settings.DNSDomain, + AutoUpdateVersion: &settings.AutoUpdateVersion, + EmbeddedIdpEnabled: &embeddedIdpEnabled, } if settings.NetworkRange.IsValid() { diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index c5c48ef32..e455372c8 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -33,6 +33,7 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler { AnyTimes() return &handler{ + embeddedIdpEnabled: false, accountManager: &mock_server.MockAccountManager{ GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) { return account.Settings, nil @@ -121,6 +122,8 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateVersion: sr(""), + EmbeddedIdpEnabled: br(false), }, expectedArray: true, expectedID: accountID, @@ -143,6 +146,32 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateVersion: sr(""), + EmbeddedIdpEnabled: br(false), + }, + expectedArray: false, + expectedID: accountID, + }, + { + name: "PutAccount OK with autoUpdateVersion", + expectedBody: true, + requestType: http.MethodPut, + requestPath: "/api/accounts/" + accountID, + requestBody: bytes.NewBufferString("{\"settings\": {\"auto_update_version\": \"latest\", \"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"), + expectedStatus: http.StatusOK, + expectedSettings: api.AccountSettings{ + PeerLoginExpiration: 15552000, + PeerLoginExpirationEnabled: true, + GroupsPropagationEnabled: br(false), + JwtGroupsClaimName: sr(""), + JwtGroupsEnabled: br(false), + JwtAllowGroups: &[]string{}, + RegularUsersViewBlocked: false, + RoutingPeerDnsResolutionEnabled: br(false), + LazyConnectionEnabled: br(false), + DnsDomain: sr(""), + AutoUpdateVersion: sr("latest"), + EmbeddedIdpEnabled: br(false), }, expectedArray: false, expectedID: accountID, @@ -165,6 +194,8 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateVersion: sr(""), + EmbeddedIdpEnabled: br(false), }, expectedArray: false, expectedID: accountID, @@ -187,6 +218,8 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateVersion: sr(""), + EmbeddedIdpEnabled: br(false), }, expectedArray: false, expectedID: accountID, @@ -209,6 +242,8 @@ func TestAccounts_AccountsHandler(t *testing.T) { RoutingPeerDnsResolutionEnabled: br(false), LazyConnectionEnabled: br(false), DnsDomain: sr(""), + AutoUpdateVersion: sr(""), + EmbeddedIdpEnabled: br(false), }, expectedArray: false, expectedID: accountID, diff --git a/management/server/http/handlers/idp/idp_handler.go b/management/server/http/handlers/idp/idp_handler.go new file mode 100644 index 000000000..077507b89 --- /dev/null +++ b/management/server/http/handlers/idp/idp_handler.go @@ -0,0 +1,196 @@ +package idp + +import ( + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + + "github.com/netbirdio/netbird/management/server/account" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" + "github.com/netbirdio/netbird/shared/management/status" +) + +// handler handles identity provider HTTP endpoints +type handler struct { + accountManager account.Manager +} + +// AddEndpoints registers identity provider endpoints +func AddEndpoints(accountManager account.Manager, router *mux.Router) { + h := newHandler(accountManager) + router.HandleFunc("/identity-providers", h.getAllIdentityProviders).Methods("GET", "OPTIONS") + router.HandleFunc("/identity-providers", h.createIdentityProvider).Methods("POST", "OPTIONS") + router.HandleFunc("/identity-providers/{idpId}", h.getIdentityProvider).Methods("GET", "OPTIONS") + router.HandleFunc("/identity-providers/{idpId}", h.updateIdentityProvider).Methods("PUT", "OPTIONS") + router.HandleFunc("/identity-providers/{idpId}", h.deleteIdentityProvider).Methods("DELETE", "OPTIONS") +} + +func newHandler(accountManager account.Manager) *handler { + return &handler{ + accountManager: accountManager, + } +} + +// getAllIdentityProviders returns all identity providers for the account +func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + accountID, userID := userAuth.AccountId, userAuth.UserId + + providers, err := h.accountManager.GetIdentityProviders(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + response := make([]api.IdentityProvider, 0, len(providers)) + for _, p := range providers { + response = append(response, toAPIResponse(p)) + } + + util.WriteJSONObject(r.Context(), w, response) +} + +// getIdentityProvider returns a specific identity provider +func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + accountID, userID := userAuth.AccountId, userAuth.UserId + + vars := mux.Vars(r) + idpID := vars["idpId"] + if idpID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "identity provider ID is required"), w) + return + } + + provider, err := h.accountManager.GetIdentityProvider(r.Context(), accountID, idpID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, toAPIResponse(provider)) +} + +// createIdentityProvider creates a new identity provider +func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + accountID, userID := userAuth.AccountId, userAuth.UserId + + var req api.IdentityProviderRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + idp := fromAPIRequest(&req) + + created, err := h.accountManager.CreateIdentityProvider(r.Context(), accountID, userID, idp) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, toAPIResponse(created)) +} + +// updateIdentityProvider updates an existing identity provider +func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + accountID, userID := userAuth.AccountId, userAuth.UserId + + vars := mux.Vars(r) + idpID := vars["idpId"] + if idpID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "identity provider ID is required"), w) + return + } + + var req api.IdentityProviderRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + idp := fromAPIRequest(&req) + + updated, err := h.accountManager.UpdateIdentityProvider(r.Context(), accountID, idpID, userID, idp) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, toAPIResponse(updated)) +} + +// deleteIdentityProvider deletes an identity provider +func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + accountID, userID := userAuth.AccountId, userAuth.UserId + + vars := mux.Vars(r) + idpID := vars["idpId"] + if idpID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "identity provider ID is required"), w) + return + } + + if err := h.accountManager.DeleteIdentityProvider(r.Context(), accountID, idpID, userID); err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} + +func toAPIResponse(idp *types.IdentityProvider) api.IdentityProvider { + resp := api.IdentityProvider{ + Type: api.IdentityProviderType(idp.Type), + Name: idp.Name, + Issuer: idp.Issuer, + ClientId: idp.ClientID, + } + if idp.ID != "" { + resp.Id = &idp.ID + } + // Note: ClientSecret is never returned in responses for security + return resp +} + +func fromAPIRequest(req *api.IdentityProviderRequest) *types.IdentityProvider { + return &types.IdentityProvider{ + Type: types.IdentityProviderType(req.Type), + Name: req.Name, + Issuer: req.Issuer, + ClientID: req.ClientId, + ClientSecret: req.ClientSecret, + } +} diff --git a/management/server/http/handlers/idp/idp_handler_test.go b/management/server/http/handlers/idp/idp_handler_test.go new file mode 100644 index 000000000..74b204048 --- /dev/null +++ b/management/server/http/handlers/idp/idp_handler_test.go @@ -0,0 +1,438 @@ +package idp + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/status" +) + +const ( + testAccountID = "test-account-id" + testUserID = "test-user-id" + existingIDPID = "existing-idp-id" + newIDPID = "new-idp-id" +) + +func initIDPTestData(existingIDP *types.IdentityProvider) *handler { + return &handler{ + accountManager: &mock_server.MockAccountManager{ + GetIdentityProvidersFunc: func(_ context.Context, accountID, userID string) ([]*types.IdentityProvider, error) { + if accountID != testAccountID { + return nil, status.Errorf(status.NotFound, "account not found") + } + if existingIDP != nil { + return []*types.IdentityProvider{existingIDP}, nil + } + return []*types.IdentityProvider{}, nil + }, + GetIdentityProviderFunc: func(_ context.Context, accountID, idpID, userID string) (*types.IdentityProvider, error) { + if accountID != testAccountID { + return nil, status.Errorf(status.NotFound, "account not found") + } + if existingIDP != nil && idpID == existingIDP.ID { + return existingIDP, nil + } + return nil, status.Errorf(status.NotFound, "identity provider not found") + }, + CreateIdentityProviderFunc: func(_ context.Context, accountID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error) { + if accountID != testAccountID { + return nil, status.Errorf(status.NotFound, "account not found") + } + if idp.Name == "" { + return nil, status.Errorf(status.InvalidArgument, "name is required") + } + created := idp.Copy() + created.ID = newIDPID + created.AccountID = accountID + return created, nil + }, + UpdateIdentityProviderFunc: func(_ context.Context, accountID, idpID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error) { + if accountID != testAccountID { + return nil, status.Errorf(status.NotFound, "account not found") + } + if existingIDP == nil || idpID != existingIDP.ID { + return nil, status.Errorf(status.NotFound, "identity provider not found") + } + updated := idp.Copy() + updated.ID = idpID + updated.AccountID = accountID + return updated, nil + }, + DeleteIdentityProviderFunc: func(_ context.Context, accountID, idpID, userID string) error { + if accountID != testAccountID { + return status.Errorf(status.NotFound, "account not found") + } + if existingIDP == nil || idpID != existingIDP.ID { + return status.Errorf(status.NotFound, "identity provider not found") + } + return nil + }, + }, + } +} + +func TestGetAllIdentityProviders(t *testing.T) { + existingIDP := &types.IdentityProvider{ + ID: existingIDPID, + Name: "Test IDP", + Type: types.IdentityProviderTypeOIDC, + Issuer: "https://issuer.example.com", + ClientID: "client-id", + } + + tt := []struct { + name string + expectedStatus int + expectedCount int + }{ + { + name: "Get All Identity Providers", + expectedStatus: http.StatusOK, + expectedCount: 1, + }, + } + + h := initIDPTestData(existingIDP) + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/identity-providers", nil) + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ + UserId: testUserID, + AccountId: testAccountID, + }) + + router := mux.NewRouter() + router.HandleFunc("/api/identity-providers", h.getAllIdentityProviders).Methods("GET") + router.ServeHTTP(recorder, req) + + res := recorder.Result() + defer res.Body.Close() + + assert.Equal(t, tc.expectedStatus, recorder.Code) + + content, err := io.ReadAll(res.Body) + require.NoError(t, err) + + var idps []api.IdentityProvider + err = json.Unmarshal(content, &idps) + require.NoError(t, err) + assert.Len(t, idps, tc.expectedCount) + }) + } +} + +func TestGetIdentityProvider(t *testing.T) { + existingIDP := &types.IdentityProvider{ + ID: existingIDPID, + Name: "Test IDP", + Type: types.IdentityProviderTypeOIDC, + Issuer: "https://issuer.example.com", + ClientID: "client-id", + } + + tt := []struct { + name string + idpID string + expectedStatus int + expectedBody bool + }{ + { + name: "Get Existing Identity Provider", + idpID: existingIDPID, + expectedStatus: http.StatusOK, + expectedBody: true, + }, + { + name: "Get Non-Existing Identity Provider", + idpID: "non-existing-id", + expectedStatus: http.StatusNotFound, + expectedBody: false, + }, + } + + h := initIDPTestData(existingIDP) + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/identity-providers/%s", tc.idpID), nil) + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ + UserId: testUserID, + AccountId: testAccountID, + }) + + router := mux.NewRouter() + router.HandleFunc("/api/identity-providers/{idpId}", h.getIdentityProvider).Methods("GET") + router.ServeHTTP(recorder, req) + + res := recorder.Result() + defer res.Body.Close() + + assert.Equal(t, tc.expectedStatus, recorder.Code) + + if tc.expectedBody { + content, err := io.ReadAll(res.Body) + require.NoError(t, err) + + var idp api.IdentityProvider + err = json.Unmarshal(content, &idp) + require.NoError(t, err) + assert.Equal(t, existingIDPID, *idp.Id) + assert.Equal(t, existingIDP.Name, idp.Name) + } + }) + } +} + +func TestCreateIdentityProvider(t *testing.T) { + tt := []struct { + name string + requestBody string + expectedStatus int + expectedBody bool + }{ + { + name: "Create Identity Provider", + requestBody: `{ + "name": "New IDP", + "type": "oidc", + "issuer": "https://new-issuer.example.com", + "client_id": "new-client-id", + "client_secret": "new-client-secret" + }`, + expectedStatus: http.StatusOK, + expectedBody: true, + }, + { + name: "Create Identity Provider with Invalid JSON", + requestBody: `{invalid json`, + expectedStatus: http.StatusBadRequest, + expectedBody: false, + }, + } + + h := initIDPTestData(nil) + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/identity-providers", bytes.NewBufferString(tc.requestBody)) + req.Header.Set("Content-Type", "application/json") + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ + UserId: testUserID, + AccountId: testAccountID, + }) + + router := mux.NewRouter() + router.HandleFunc("/api/identity-providers", h.createIdentityProvider).Methods("POST") + router.ServeHTTP(recorder, req) + + res := recorder.Result() + defer res.Body.Close() + + assert.Equal(t, tc.expectedStatus, recorder.Code) + + if tc.expectedBody { + content, err := io.ReadAll(res.Body) + require.NoError(t, err) + + var idp api.IdentityProvider + err = json.Unmarshal(content, &idp) + require.NoError(t, err) + assert.Equal(t, newIDPID, *idp.Id) + assert.Equal(t, "New IDP", idp.Name) + assert.Equal(t, api.IdentityProviderTypeOidc, idp.Type) + } + }) + } +} + +func TestUpdateIdentityProvider(t *testing.T) { + existingIDP := &types.IdentityProvider{ + ID: existingIDPID, + Name: "Test IDP", + Type: types.IdentityProviderTypeOIDC, + Issuer: "https://issuer.example.com", + ClientID: "client-id", + ClientSecret: "client-secret", + } + + tt := []struct { + name string + idpID string + requestBody string + expectedStatus int + expectedBody bool + }{ + { + name: "Update Existing Identity Provider", + idpID: existingIDPID, + requestBody: `{ + "name": "Updated IDP", + "type": "oidc", + "issuer": "https://updated-issuer.example.com", + "client_id": "updated-client-id" + }`, + expectedStatus: http.StatusOK, + expectedBody: true, + }, + { + name: "Update Non-Existing Identity Provider", + idpID: "non-existing-id", + requestBody: `{ + "name": "Updated IDP", + "type": "oidc", + "issuer": "https://updated-issuer.example.com", + "client_id": "updated-client-id" + }`, + expectedStatus: http.StatusNotFound, + expectedBody: false, + }, + { + name: "Update Identity Provider with Invalid JSON", + idpID: existingIDPID, + requestBody: `{invalid json`, + expectedStatus: http.StatusBadRequest, + expectedBody: false, + }, + } + + h := initIDPTestData(existingIDP) + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/api/identity-providers/%s", tc.idpID), bytes.NewBufferString(tc.requestBody)) + req.Header.Set("Content-Type", "application/json") + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ + UserId: testUserID, + AccountId: testAccountID, + }) + + router := mux.NewRouter() + router.HandleFunc("/api/identity-providers/{idpId}", h.updateIdentityProvider).Methods("PUT") + router.ServeHTTP(recorder, req) + + res := recorder.Result() + defer res.Body.Close() + + assert.Equal(t, tc.expectedStatus, recorder.Code) + + if tc.expectedBody { + content, err := io.ReadAll(res.Body) + require.NoError(t, err) + + var idp api.IdentityProvider + err = json.Unmarshal(content, &idp) + require.NoError(t, err) + assert.Equal(t, existingIDPID, *idp.Id) + assert.Equal(t, "Updated IDP", idp.Name) + } + }) + } +} + +func TestDeleteIdentityProvider(t *testing.T) { + existingIDP := &types.IdentityProvider{ + ID: existingIDPID, + Name: "Test IDP", + Type: types.IdentityProviderTypeOIDC, + Issuer: "https://issuer.example.com", + ClientID: "client-id", + } + + tt := []struct { + name string + idpID string + expectedStatus int + }{ + { + name: "Delete Existing Identity Provider", + idpID: existingIDPID, + expectedStatus: http.StatusOK, + }, + { + name: "Delete Non-Existing Identity Provider", + idpID: "non-existing-id", + expectedStatus: http.StatusNotFound, + }, + } + + h := initIDPTestData(existingIDP) + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/api/identity-providers/%s", tc.idpID), nil) + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ + UserId: testUserID, + AccountId: testAccountID, + }) + + router := mux.NewRouter() + router.HandleFunc("/api/identity-providers/{idpId}", h.deleteIdentityProvider).Methods("DELETE") + router.ServeHTTP(recorder, req) + + res := recorder.Result() + defer res.Body.Close() + + assert.Equal(t, tc.expectedStatus, recorder.Code) + }) + } +} + +func TestToAPIResponse(t *testing.T) { + idp := &types.IdentityProvider{ + ID: "test-id", + Name: "Test IDP", + Type: types.IdentityProviderTypeGoogle, + Issuer: "https://accounts.google.com", + ClientID: "client-id", + ClientSecret: "should-not-be-returned", + } + + response := toAPIResponse(idp) + + assert.Equal(t, "test-id", *response.Id) + assert.Equal(t, "Test IDP", response.Name) + assert.Equal(t, api.IdentityProviderTypeGoogle, response.Type) + assert.Equal(t, "https://accounts.google.com", response.Issuer) + assert.Equal(t, "client-id", response.ClientId) + // Note: ClientSecret is not included in response type by design +} + +func TestFromAPIRequest(t *testing.T) { + req := &api.IdentityProviderRequest{ + Name: "New IDP", + Type: api.IdentityProviderTypeOkta, + Issuer: "https://dev-123456.okta.com", + ClientId: "okta-client-id", + ClientSecret: "okta-client-secret", + } + + idp := fromAPIRequest(req) + + assert.Equal(t, "New IDP", idp.Name) + assert.Equal(t, types.IdentityProviderTypeOkta, idp.Type) + assert.Equal(t, "https://dev-123456.okta.com", idp.Issuer) + assert.Equal(t, "okta-client-id", idp.ClientID) + assert.Equal(t, "okta-client-secret", idp.ClientSecret) +} diff --git a/management/server/http/handlers/instance/instance_handler.go b/management/server/http/handlers/instance/instance_handler.go new file mode 100644 index 000000000..889c3133e --- /dev/null +++ b/management/server/http/handlers/instance/instance_handler.go @@ -0,0 +1,67 @@ +package instance + +import ( + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + log "github.com/sirupsen/logrus" + + nbinstance "github.com/netbirdio/netbird/management/server/instance" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" +) + +// handler handles the instance setup HTTP endpoints +type handler struct { + instanceManager nbinstance.Manager +} + +// AddEndpoints registers the instance setup endpoints. +// These endpoints bypass authentication for initial setup. +func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) { + h := &handler{ + instanceManager: instanceManager, + } + + router.HandleFunc("/instance", h.getInstanceStatus).Methods("GET", "OPTIONS") + router.HandleFunc("/setup", h.setup).Methods("POST", "OPTIONS") +} + +// getInstanceStatus returns the instance status including whether setup is required. +// This endpoint is unauthenticated. +func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) { + setupRequired, err := h.instanceManager.IsSetupRequired(r.Context()) + if err != nil { + log.WithContext(r.Context()).Errorf("failed to check setup status: %v", err) + util.WriteErrorResponse("failed to check instance status", http.StatusInternalServerError, w) + return + } + + util.WriteJSONObject(r.Context(), w, api.InstanceStatus{ + SetupRequired: setupRequired, + }) +} + +// setup creates the initial admin user for the instance. +// This endpoint is unauthenticated but only works when setup is required. +func (h *handler) setup(w http.ResponseWriter, r *http.Request) { + var req api.SetupRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + util.WriteErrorResponse("invalid request body", http.StatusBadRequest, w) + return + } + + userData, err := h.instanceManager.CreateOwnerUser(r.Context(), req.Email, req.Password, req.Name) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + log.WithContext(r.Context()).Infof("instance setup completed: created user %s", req.Email) + + util.WriteJSONObject(r.Context(), w, api.SetupResponse{ + UserId: userData.ID, + Email: userData.Email, + }) +} diff --git a/management/server/http/handlers/instance/instance_handler_test.go b/management/server/http/handlers/instance/instance_handler_test.go new file mode 100644 index 000000000..7a3a2bc88 --- /dev/null +++ b/management/server/http/handlers/instance/instance_handler_test.go @@ -0,0 +1,281 @@ +package instance + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "net/mail" + "testing" + + "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/idp" + nbinstance "github.com/netbirdio/netbird/management/server/instance" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/status" +) + +// mockInstanceManager implements instance.Manager for testing +type mockInstanceManager struct { + isSetupRequired bool + isSetupRequiredFn func(ctx context.Context) (bool, error) + createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error) +} + +func (m *mockInstanceManager) IsSetupRequired(ctx context.Context) (bool, error) { + if m.isSetupRequiredFn != nil { + return m.isSetupRequiredFn(ctx) + } + return m.isSetupRequired, nil +} + +func (m *mockInstanceManager) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) { + if m.createOwnerUserFn != nil { + return m.createOwnerUserFn(ctx, email, password, name) + } + + // Default mock includes validation like the real manager + if !m.isSetupRequired { + return nil, status.Errorf(status.PreconditionFailed, "setup already completed") + } + if email == "" { + return nil, status.Errorf(status.InvalidArgument, "email is required") + } + if _, err := mail.ParseAddress(email); err != nil { + return nil, status.Errorf(status.InvalidArgument, "invalid email format") + } + if name == "" { + return nil, status.Errorf(status.InvalidArgument, "name is required") + } + if password == "" { + return nil, status.Errorf(status.InvalidArgument, "password is required") + } + if len(password) < 8 { + return nil, status.Errorf(status.InvalidArgument, "password must be at least 8 characters") + } + + return &idp.UserData{ + ID: "test-user-id", + Email: email, + Name: name, + }, nil +} + +var _ nbinstance.Manager = (*mockInstanceManager)(nil) + +func setupTestRouter(manager nbinstance.Manager) *mux.Router { + router := mux.NewRouter() + AddEndpoints(manager, router) + return router +} + +func TestGetInstanceStatus_SetupRequired(t *testing.T) { + manager := &mockInstanceManager{isSetupRequired: true} + router := setupTestRouter(manager) + + req := httptest.NewRequest(http.MethodGet, "/instance", nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response api.InstanceStatus + err := json.NewDecoder(rec.Body).Decode(&response) + require.NoError(t, err) + assert.True(t, response.SetupRequired) +} + +func TestGetInstanceStatus_SetupNotRequired(t *testing.T) { + manager := &mockInstanceManager{isSetupRequired: false} + router := setupTestRouter(manager) + + req := httptest.NewRequest(http.MethodGet, "/instance", nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response api.InstanceStatus + err := json.NewDecoder(rec.Body).Decode(&response) + require.NoError(t, err) + assert.False(t, response.SetupRequired) +} + +func TestGetInstanceStatus_Error(t *testing.T) { + manager := &mockInstanceManager{ + isSetupRequiredFn: func(ctx context.Context) (bool, error) { + return false, errors.New("database error") + }, + } + router := setupTestRouter(manager) + + req := httptest.NewRequest(http.MethodGet, "/instance", nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusInternalServerError, rec.Code) +} + +func TestSetup_Success(t *testing.T) { + manager := &mockInstanceManager{ + isSetupRequired: true, + createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) { + assert.Equal(t, "admin@example.com", email) + assert.Equal(t, "securepassword123", password) + assert.Equal(t, "Admin User", name) + return &idp.UserData{ + ID: "created-user-id", + Email: email, + Name: name, + }, nil + }, + } + router := setupTestRouter(manager) + + body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin User"}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response api.SetupResponse + err := json.NewDecoder(rec.Body).Decode(&response) + require.NoError(t, err) + assert.Equal(t, "created-user-id", response.UserId) + assert.Equal(t, "admin@example.com", response.Email) +} + +func TestSetup_AlreadyCompleted(t *testing.T) { + manager := &mockInstanceManager{isSetupRequired: false} + router := setupTestRouter(manager) + + body := `{"email": "admin@example.com", "password": "securepassword123"}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusPreconditionFailed, rec.Code) +} + +func TestSetup_MissingEmail(t *testing.T) { + manager := &mockInstanceManager{isSetupRequired: true} + router := setupTestRouter(manager) + + body := `{"password": "securepassword123"}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusUnprocessableEntity, rec.Code) +} + +func TestSetup_InvalidEmail(t *testing.T) { + manager := &mockInstanceManager{isSetupRequired: true} + router := setupTestRouter(manager) + + body := `{"email": "not-an-email", "password": "securepassword123", "name": "User"}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + // Note: Invalid email format uses mail.ParseAddress which is treated differently + // and returns 400 Bad Request instead of 422 Unprocessable Entity + assert.Equal(t, http.StatusUnprocessableEntity, rec.Code) +} + +func TestSetup_MissingPassword(t *testing.T) { + manager := &mockInstanceManager{isSetupRequired: true} + router := setupTestRouter(manager) + + body := `{"email": "admin@example.com", "name": "User"}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusUnprocessableEntity, rec.Code) +} + +func TestSetup_PasswordTooShort(t *testing.T) { + manager := &mockInstanceManager{isSetupRequired: true} + router := setupTestRouter(manager) + + body := `{"email": "admin@example.com", "password": "short", "name": "User"}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusUnprocessableEntity, rec.Code) +} + +func TestSetup_InvalidJSON(t *testing.T) { + manager := &mockInstanceManager{isSetupRequired: true} + router := setupTestRouter(manager) + + body := `{invalid json}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestSetup_CreateUserError(t *testing.T) { + manager := &mockInstanceManager{ + isSetupRequired: true, + createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) { + return nil, errors.New("user creation failed") + }, + } + router := setupTestRouter(manager) + + body := `{"email": "admin@example.com", "password": "securepassword123", "name": "User"}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusInternalServerError, rec.Code) +} + +func TestSetup_ManagerError(t *testing.T) { + manager := &mockInstanceManager{ + isSetupRequired: true, + createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) { + return nil, status.Errorf(status.Internal, "database error") + }, + } + router := setupTestRouter(manager) + + body := `{"email": "admin@example.com", "password": "securepassword123", "name": "User"}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusInternalServerError, rec.Code) +} diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index f531f0cdb..a5c9ab0ac 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -299,7 +299,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.networkMapController.GetDNSDomain(account.Settings) customZone := account.GetPeersCustomZone(r.Context(), dnsDomain) - netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers()) util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) } @@ -369,6 +369,9 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request) PortRanges: []types.RulePortRange{portRange}, }}, } + if protocol == types.PolicyRuleProtocolNetbirdSSH { + policy.Rules[0].AuthorizedUser = userAuth.UserId + } _, err = h.accountManager.SavePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policy, true) if err != nil { @@ -449,6 +452,18 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD SerialNumber: peer.Meta.SystemSerialNumber, InactivityExpirationEnabled: peer.InactivityExpirationEnabled, Ephemeral: peer.Ephemeral, + LocalFlags: &api.PeerLocalFlags{ + BlockInbound: &peer.Meta.Flags.BlockInbound, + BlockLanAccess: &peer.Meta.Flags.BlockLANAccess, + DisableClientRoutes: &peer.Meta.Flags.DisableClientRoutes, + DisableDns: &peer.Meta.Flags.DisableDNS, + DisableFirewall: &peer.Meta.Flags.DisableFirewall, + DisableServerRoutes: &peer.Meta.Flags.DisableServerRoutes, + LazyConnectionEnabled: &peer.Meta.Flags.LazyConnectionEnabled, + RosenpassEnabled: &peer.Meta.Flags.RosenpassEnabled, + RosenpassPermissive: &peer.Meta.Flags.RosenpassPermissive, + ServerSshAllowed: &peer.Meta.Flags.ServerSSHAllowed, + }, } if !approved { @@ -463,7 +478,6 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn if osVersion == "" { osVersion = peer.Meta.Core } - return &api.PeerBatch{ CreatedAt: peer.CreatedAt, Id: peer.ID, @@ -492,6 +506,18 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn SerialNumber: peer.Meta.SystemSerialNumber, InactivityExpirationEnabled: peer.InactivityExpirationEnabled, Ephemeral: peer.Ephemeral, + LocalFlags: &api.PeerLocalFlags{ + BlockInbound: &peer.Meta.Flags.BlockInbound, + BlockLanAccess: &peer.Meta.Flags.BlockLANAccess, + DisableClientRoutes: &peer.Meta.Flags.DisableClientRoutes, + DisableDns: &peer.Meta.Flags.DisableDNS, + DisableFirewall: &peer.Meta.Flags.DisableFirewall, + DisableServerRoutes: &peer.Meta.Flags.DisableServerRoutes, + LazyConnectionEnabled: &peer.Meta.Flags.LazyConnectionEnabled, + RosenpassEnabled: &peer.Meta.Flags.RosenpassEnabled, + RosenpassPermissive: &peer.Meta.Flags.RosenpassPermissive, + ServerSshAllowed: &peer.Meta.Flags.ServerSSHAllowed, + }, } } diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index 55e779ff0..869a39b5e 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -66,7 +66,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler { }, } - srvUser := types.NewRegularUser(serviceUser) + srvUser := types.NewRegularUser(serviceUser, "", "") srvUser.IsServiceUser = true account := &types.Account{ @@ -75,7 +75,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler { Peers: peersMap, Users: map[string]*types.User{ adminUser: types.NewAdminUser(adminUser), - regularUser: types.NewRegularUser(regularUser), + regularUser: types.NewRegularUser(regularUser, "", ""), serviceUser: srvUser, }, Groups: map[string]*types.Group{ diff --git a/management/server/http/handlers/policies/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go index ab1639ab1..e4d1d73df 100644 --- a/management/server/http/handlers/policies/policies_handler.go +++ b/management/server/http/handlers/policies/policies_handler.go @@ -221,6 +221,8 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s pr.Protocol = types.PolicyRuleProtocolUDP case api.PolicyRuleUpdateProtocolIcmp: pr.Protocol = types.PolicyRuleProtocolICMP + case api.PolicyRuleUpdateProtocolNetbirdSsh: + pr.Protocol = types.PolicyRuleProtocolNetbirdSSH default: util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w) return @@ -254,6 +256,17 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s } } + if pr.Protocol == types.PolicyRuleProtocolNetbirdSSH && rule.AuthorizedGroups != nil && len(*rule.AuthorizedGroups) != 0 { + for _, sourceGroupID := range pr.Sources { + _, ok := (*rule.AuthorizedGroups)[sourceGroupID] + if !ok { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "authorized group for netbird-ssh protocol should be specified for each source group"), w) + return + } + } + pr.AuthorizedGroups = *rule.AuthorizedGroups + } + // validate policy object if pr.Protocol == types.PolicyRuleProtocolALL || pr.Protocol == types.PolicyRuleProtocolICMP { if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 { @@ -380,6 +393,11 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy { DestinationResource: r.DestinationResource.ToAPIResponse(), } + if len(r.AuthorizedGroups) != 0 { + authorizedGroupsCopy := r.AuthorizedGroups + rule.AuthorizedGroups = &authorizedGroupsCopy + } + if len(r.Ports) != 0 { portsCopy := r.Ports rule.Ports = &portsCopy diff --git a/management/server/http/handlers/users/users_handler.go b/management/server/http/handlers/users/users_handler.go index 4e03e5e9b..7669d7404 100644 --- a/management/server/http/handlers/users/users_handler.go +++ b/management/server/http/handlers/users/users_handler.go @@ -326,6 +326,16 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User { isCurrent := user.ID == currenUserID + var password *string + if user.Password != "" { + password = &user.Password + } + + var idpID *string + if user.IdPID != "" { + idpID = &user.IdPID + } + return &api.User{ Id: user.ID, Name: user.Name, @@ -339,6 +349,8 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User { LastLogin: &user.LastLogin, Issued: &user.Issued, PendingApproval: user.PendingApproval, + Password: password, + IdpId: idpID, } } diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 38cf0c290..966a6802a 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -134,6 +134,9 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts [] userAuth.IsChild = ok } + // Email is now extracted in ToUserAuth (from claims or userinfo endpoint) + // Available as userAuth.Email + // we need to call this method because if user is new, we will automatically add it to existing or create a new account accountId, _, err := m.ensureAccount(ctx, userAuth) if err != nil { diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index e8513feb5..656f72997 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -94,7 +94,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee groupsManagerMock := groups.NewManagerMock() peersManager := peers.NewManager(store, permissionsManager) - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, networkMapController) + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, networkMapController, nil) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } diff --git a/management/server/identity_provider.go b/management/server/identity_provider.go new file mode 100644 index 000000000..6649c3953 --- /dev/null +++ b/management/server/identity_provider.go @@ -0,0 +1,234 @@ +package server + +import ( + "context" + "errors" + + "github.com/dexidp/dex/storage" + "github.com/rs/xid" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/idp/dex" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/status" +) + +// GetIdentityProviders returns all identity providers for an account +func (am *DefaultAccountManager) GetIdentityProviders(ctx context.Context, accountID, userID string) ([]*types.IdentityProvider, error) { + ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager) + if !ok { + log.Warn("identity provider management requires embedded IdP") + return []*types.IdentityProvider{}, nil + } + + connectors, err := embeddedManager.ListConnectors(ctx) + if err != nil { + return nil, status.Errorf(status.Internal, "failed to list identity providers: %v", err) + } + + result := make([]*types.IdentityProvider, 0, len(connectors)) + for _, conn := range connectors { + result = append(result, connectorConfigToIdentityProvider(conn, accountID)) + } + + return result, nil +} + +// GetIdentityProvider returns a specific identity provider by ID +func (am *DefaultAccountManager) GetIdentityProvider(ctx context.Context, accountID, idpID, userID string) (*types.IdentityProvider, error) { + ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager) + if !ok { + return nil, status.Errorf(status.Internal, "identity provider management requires embedded IdP") + } + + conn, err := embeddedManager.GetConnector(ctx, idpID) + if err != nil { + if errors.Is(err, storage.ErrNotFound) { + return nil, status.Errorf(status.NotFound, "identity provider not found") + } + return nil, status.Errorf(status.Internal, "failed to get identity provider: %v", err) + } + + return connectorConfigToIdentityProvider(conn, accountID), nil +} + +// CreateIdentityProvider creates a new identity provider +func (am *DefaultAccountManager) CreateIdentityProvider(ctx context.Context, accountID, userID string, idpConfig *types.IdentityProvider) (*types.IdentityProvider, error) { + ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Create) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + if err := idpConfig.Validate(); err != nil { + return nil, status.Errorf(status.InvalidArgument, "%s", err.Error()) + } + + embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager) + if !ok { + return nil, status.Errorf(status.Internal, "identity provider management requires embedded IdP") + } + + // Generate ID if not provided + if idpConfig.ID == "" { + idpConfig.ID = generateIdentityProviderID(idpConfig.Type) + } + idpConfig.AccountID = accountID + + connCfg := identityProviderToConnectorConfig(idpConfig) + + _, err = embeddedManager.CreateConnector(ctx, connCfg) + if err != nil { + return nil, status.Errorf(status.Internal, "failed to create identity provider: %v", err) + } + + am.StoreEvent(ctx, userID, idpConfig.ID, accountID, activity.IdentityProviderCreated, idpConfig.EventMeta()) + + return idpConfig, nil +} + +// UpdateIdentityProvider updates an existing identity provider +func (am *DefaultAccountManager) UpdateIdentityProvider(ctx context.Context, accountID, idpID, userID string, idpConfig *types.IdentityProvider) (*types.IdentityProvider, error) { + ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Update) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + if err := idpConfig.Validate(); err != nil { + return nil, status.Errorf(status.InvalidArgument, "%s", err.Error()) + } + + embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager) + if !ok { + return nil, status.Errorf(status.Internal, "identity provider management requires embedded IdP") + } + + idpConfig.ID = idpID + idpConfig.AccountID = accountID + + connCfg := identityProviderToConnectorConfig(idpConfig) + + if err := embeddedManager.UpdateConnector(ctx, connCfg); err != nil { + return nil, status.Errorf(status.Internal, "failed to update identity provider: %v", err) + } + + am.StoreEvent(ctx, userID, idpConfig.ID, accountID, activity.IdentityProviderUpdated, idpConfig.EventMeta()) + + return idpConfig, nil +} + +// DeleteIdentityProvider deletes an identity provider +func (am *DefaultAccountManager) DeleteIdentityProvider(ctx context.Context, accountID, idpID, userID string) error { + ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Delete) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !ok { + return status.NewPermissionDeniedError() + } + + embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager) + if !ok { + return status.Errorf(status.Internal, "identity provider management requires embedded IdP") + } + + // Get the IDP info before deleting for the activity event + conn, err := embeddedManager.GetConnector(ctx, idpID) + if err != nil { + if errors.Is(err, storage.ErrNotFound) { + return status.Errorf(status.NotFound, "identity provider not found") + } + return status.Errorf(status.Internal, "failed to get identity provider: %v", err) + } + idpConfig := connectorConfigToIdentityProvider(conn, accountID) + + if err := embeddedManager.DeleteConnector(ctx, idpID); err != nil { + if errors.Is(err, storage.ErrNotFound) { + return status.Errorf(status.NotFound, "identity provider not found") + } + return status.Errorf(status.Internal, "failed to delete identity provider: %v", err) + } + + am.StoreEvent(ctx, userID, idpID, accountID, activity.IdentityProviderDeleted, idpConfig.EventMeta()) + + return nil +} + +// connectorConfigToIdentityProvider converts a dex.ConnectorConfig to types.IdentityProvider +func connectorConfigToIdentityProvider(conn *dex.ConnectorConfig, accountID string) *types.IdentityProvider { + return &types.IdentityProvider{ + ID: conn.ID, + AccountID: accountID, + Type: types.IdentityProviderType(conn.Type), + Name: conn.Name, + Issuer: conn.Issuer, + ClientID: conn.ClientID, + ClientSecret: conn.ClientSecret, + } +} + +// identityProviderToConnectorConfig converts a types.IdentityProvider to dex.ConnectorConfig +func identityProviderToConnectorConfig(idpConfig *types.IdentityProvider) *dex.ConnectorConfig { + return &dex.ConnectorConfig{ + ID: idpConfig.ID, + Name: idpConfig.Name, + Type: string(idpConfig.Type), + Issuer: idpConfig.Issuer, + ClientID: idpConfig.ClientID, + ClientSecret: idpConfig.ClientSecret, + } +} + +// generateIdentityProviderID generates a unique ID for an identity provider. +// For specific provider types (okta, zitadel, entra, google, pocketid, microsoft), +// the ID is prefixed with the type name. Generic OIDC providers get no prefix. +func generateIdentityProviderID(idpType types.IdentityProviderType) string { + id := xid.New().String() + + switch idpType { + case types.IdentityProviderTypeOkta: + return "okta-" + id + case types.IdentityProviderTypeZitadel: + return "zitadel-" + id + case types.IdentityProviderTypeEntra: + return "entra-" + id + case types.IdentityProviderTypeGoogle: + return "google-" + id + case types.IdentityProviderTypePocketID: + return "pocketid-" + id + case types.IdentityProviderTypeMicrosoft: + return "microsoft-" + id + case types.IdentityProviderTypeAuthentik: + return "authentik-" + id + case types.IdentityProviderTypeKeycloak: + return "keycloak-" + id + default: + // Generic OIDC - no prefix + return id + } +} diff --git a/management/server/identity_provider_test.go b/management/server/identity_provider_test.go new file mode 100644 index 000000000..d637c4a8f --- /dev/null +++ b/management/server/identity_provider_test.go @@ -0,0 +1,202 @@ +package server + +import ( + "context" + "path/filepath" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" + "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" +) + +func createManagerWithEmbeddedIdP(t testing.TB) (*DefaultAccountManager, *update_channel.PeersUpdateManager, error) { + t.Helper() + + ctx := context.Background() + + dataDir := t.TempDir() + testStore, cleanUp, err := store.NewTestStoreFromSQL(ctx, "", dataDir) + if err != nil { + return nil, nil, err + } + t.Cleanup(cleanUp) + + // Create embedded IdP manager + embeddedConfig := &idp.EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + Storage: idp.EmbeddedStorageConfig{ + Type: "sqlite3", + Config: idp.EmbeddedStorageTypeConfig{ + File: filepath.Join(dataDir, "dex.db"), + }, + }, + } + + idpManager, err := idp.NewEmbeddedIdPManager(ctx, embeddedConfig, nil) + if err != nil { + return nil, nil, err + } + t.Cleanup(func() { _ = idpManager.Stop(ctx) }) + + eventStore := &activity.InMemoryEventStore{} + + metrics, err := telemetry.NewDefaultAppMetrics(ctx) + if err != nil { + return nil, nil, err + } + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + settingsMockManager := settings.NewMockManager(ctrl) + settingsMockManager.EXPECT(). + GetExtraSettings(gomock.Any(), gomock.Any()). + Return(&types.ExtraSettings{}, nil). + AnyTimes() + settingsMockManager.EXPECT(). + UpdateExtraSettings(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(false, nil). + AnyTimes() + + permissionsManager := permissions.NewManager(testStore) + + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, testStore) + networkMapController := controller.NewController(ctx, testStore, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(testStore, peers.NewManager(testStore, permissionsManager)), &config.Config{}) + manager, err := BuildManager(ctx, &config.Config{}, testStore, networkMapController, idpManager, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + if err != nil { + return nil, nil, err + } + + return manager, updateManager, nil +} + +func TestDefaultAccountManager_CreateIdentityProvider_Validation(t *testing.T) { + manager, _, err := createManager(t) + require.NoError(t, err) + + userID := "testingUser" + account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID}) + require.NoError(t, err) + + testCases := []struct { + name string + idp *types.IdentityProvider + expectError bool + errorMsg string + }{ + { + name: "Missing Name", + idp: &types.IdentityProvider{ + Type: types.IdentityProviderTypeOIDC, + Issuer: "https://issuer.example.com", + ClientID: "client-id", + }, + expectError: true, + errorMsg: "name is required", + }, + { + name: "Missing Type", + idp: &types.IdentityProvider{ + Name: "Test IDP", + Issuer: "https://issuer.example.com", + ClientID: "client-id", + }, + expectError: true, + errorMsg: "type is required", + }, + { + name: "Missing Issuer", + idp: &types.IdentityProvider{ + Name: "Test IDP", + Type: types.IdentityProviderTypeOIDC, + ClientID: "client-id", + }, + expectError: true, + errorMsg: "issuer is required", + }, + { + name: "Missing ClientID", + idp: &types.IdentityProvider{ + Name: "Test IDP", + Type: types.IdentityProviderTypeOIDC, + Issuer: "https://issuer.example.com", + }, + expectError: true, + errorMsg: "client ID is required", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := manager.CreateIdentityProvider(context.Background(), account.Id, userID, tc.idp) + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.errorMsg) + } + }) + } +} + +func TestDefaultAccountManager_GetIdentityProviders(t *testing.T) { + manager, _, err := createManager(t) + require.NoError(t, err) + + userID := "testingUser" + account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID}) + require.NoError(t, err) + + // Should return empty list (stub implementation) + providers, err := manager.GetIdentityProviders(context.Background(), account.Id, userID) + require.NoError(t, err) + assert.Empty(t, providers) +} + +func TestDefaultAccountManager_GetIdentityProvider_NotFound(t *testing.T) { + manager, _, err := createManagerWithEmbeddedIdP(t) + require.NoError(t, err) + + userID := "testingUser" + account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID}) + require.NoError(t, err) + + // Should return not found error when identity provider doesn't exist + _, err = manager.GetIdentityProvider(context.Background(), account.Id, "any-id", userID) + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} + +func TestDefaultAccountManager_UpdateIdentityProvider_Validation(t *testing.T) { + manager, _, err := createManager(t) + require.NoError(t, err) + + userID := "testingUser" + account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID}) + require.NoError(t, err) + + // Should fail validation before reaching "not implemented" error + invalidIDP := &types.IdentityProvider{ + Name: "", // Empty name should fail validation + } + + _, err = manager.UpdateIdentityProvider(context.Background(), account.Id, "some-id", userID, invalidIDP) + require.Error(t, err) + assert.Contains(t, err.Error(), "name is required") +} diff --git a/management/server/idp/dex.go b/management/server/idp/dex.go new file mode 100644 index 000000000..0cac246e1 --- /dev/null +++ b/management/server/idp/dex.go @@ -0,0 +1,445 @@ +package idp + +import ( + "context" + "encoding/base64" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "github.com/dexidp/dex/api/v2" + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/bcrypt" + "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials/insecure" + + "github.com/netbirdio/netbird/management/server/telemetry" +) + +// DexManager implements the Manager interface for Dex IDP. +// It uses Dex's gRPC API to manage users in the password database. +type DexManager struct { + grpcAddr string + httpClient ManagerHTTPClient + helper ManagerHelper + appMetrics telemetry.AppMetrics + mux sync.Mutex + conn *grpc.ClientConn +} + +// DexClientConfig Dex manager client configuration. +type DexClientConfig struct { + // GRPCAddr is the address of Dex's gRPC API (e.g., "localhost:5557") + GRPCAddr string + // Issuer is the Dex issuer URL (e.g., "https://dex.example.com/dex") + Issuer string +} + +// NewDexManager creates a new instance of DexManager. +func NewDexManager(config DexClientConfig, appMetrics telemetry.AppMetrics) (*DexManager, error) { + if config.GRPCAddr == "" { + return nil, fmt.Errorf("dex IdP configuration is incomplete, GRPCAddr is missing") + } + + httpTransport := http.DefaultTransport.(*http.Transport).Clone() + httpTransport.MaxIdleConns = 5 + + httpClient := &http.Client{ + Timeout: 10 * time.Second, + Transport: httpTransport, + } + helper := JsonParser{} + + return &DexManager{ + grpcAddr: config.GRPCAddr, + httpClient: httpClient, + helper: helper, + appMetrics: appMetrics, + }, nil +} + +// getConnection returns a gRPC connection to Dex, creating one if necessary. +// It also checks if an existing connection is still healthy and reconnects if needed. +func (dm *DexManager) getConnection(ctx context.Context) (*grpc.ClientConn, error) { + dm.mux.Lock() + defer dm.mux.Unlock() + + if dm.conn != nil { + state := dm.conn.GetState() + // If connection is shutdown or in a transient failure, close and reconnect + if state == connectivity.Shutdown || state == connectivity.TransientFailure { + log.WithContext(ctx).Debugf("Dex gRPC connection in state %s, reconnecting", state) + _ = dm.conn.Close() + dm.conn = nil + } else { + return dm.conn, nil + } + } + + log.WithContext(ctx).Debugf("connecting to Dex gRPC API at %s", dm.grpcAddr) + + conn, err := grpc.NewClient(dm.grpcAddr, + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + return nil, fmt.Errorf("failed to connect to Dex gRPC API: %w", err) + } + + dm.conn = conn + return conn, nil +} + +// getDexClient returns a Dex API client. +func (dm *DexManager) getDexClient(ctx context.Context) (api.DexClient, error) { + conn, err := dm.getConnection(ctx) + if err != nil { + return nil, err + } + return api.NewDexClient(conn), nil +} + +// encodeDexUserID encodes a user ID and connector ID into Dex's composite format. +// This is the reverse of parseDexUserID - it creates the base64-encoded protobuf +// format that Dex uses in JWT tokens. +func encodeDexUserID(userID, connectorID string) string { + // Build simple protobuf structure: + // Field 1 (tag 0x0a): user ID string + // Field 2 (tag 0x12): connector ID string + buf := make([]byte, 0, 2+len(userID)+2+len(connectorID)) + + // Field 1: user ID + buf = append(buf, 0x0a) // tag for field 1, wire type 2 (length-delimited) + buf = append(buf, byte(len(userID))) // length + buf = append(buf, []byte(userID)...) // value + + // Field 2: connector ID + buf = append(buf, 0x12) // tag for field 2, wire type 2 (length-delimited) + buf = append(buf, byte(len(connectorID))) // length + buf = append(buf, []byte(connectorID)...) // value + + return base64.StdEncoding.EncodeToString(buf) +} + +// parseDexUserID extracts the actual user ID from Dex's composite user ID. +// Dex encodes user IDs in JWT tokens as base64-encoded protobuf with format: +// - Field 1 (string): actual user ID +// - Field 2 (string): connector ID (e.g., "local") +// If the ID is not in this format, it returns the original ID. +func parseDexUserID(compositeID string) string { + // Try to decode as standard base64 + decoded, err := base64.StdEncoding.DecodeString(compositeID) + if err != nil { + // Try URL-safe base64 + decoded, err = base64.RawURLEncoding.DecodeString(compositeID) + if err != nil { + // Not base64 encoded, return as-is + return compositeID + } + } + + // Parse the simple protobuf structure + // Field 1 (tag 0x0a): user ID string + // Field 2 (tag 0x12): connector ID string + if len(decoded) < 2 { + return compositeID + } + + // Check for field 1 tag (0x0a = field 1, wire type 2/length-delimited) + if decoded[0] != 0x0a { + return compositeID + } + + // Read the length of the user ID string + length := int(decoded[1]) + if len(decoded) < 2+length { + return compositeID + } + + // Extract the user ID + userID := string(decoded[2 : 2+length]) + return userID +} + +// UpdateUserAppMetadata updates user app metadata based on userID and metadata map. +// Dex doesn't support app metadata, so this is a no-op. +func (dm *DexManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error { + return nil +} + +// GetUserDataByID requests user data from Dex via user ID. +func (dm *DexManager) GetUserDataByID(ctx context.Context, userID string, _ AppMetadata) (*UserData, error) { + if dm.appMetrics != nil { + dm.appMetrics.IDPMetrics().CountGetUserDataByID() + } + + client, err := dm.getDexClient(ctx) + if err != nil { + if dm.appMetrics != nil { + dm.appMetrics.IDPMetrics().CountRequestError() + } + return nil, err + } + + resp, err := client.ListPasswords(ctx, &api.ListPasswordReq{}) + if err != nil { + if dm.appMetrics != nil { + dm.appMetrics.IDPMetrics().CountRequestError() + } + return nil, fmt.Errorf("failed to list passwords from Dex: %w", err) + } + + // Try to parse the composite user ID from Dex JWT token + actualUserID := parseDexUserID(userID) + + for _, p := range resp.Passwords { + // Match against both the raw userID and the parsed actualUserID + if p.UserId == userID || p.UserId == actualUserID { + return &UserData{ + Email: p.Email, + Name: p.Username, + ID: userID, // Return the original ID for consistency + }, nil + } + } + + return nil, fmt.Errorf("user with ID %s not found", userID) +} + +// GetAccount returns all the users for a given account. +// Since Dex doesn't have account concepts, this returns all users. +func (dm *DexManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) { + if dm.appMetrics != nil { + dm.appMetrics.IDPMetrics().CountGetAccount() + } + + users, err := dm.getAllUsers(ctx) + if err != nil { + return nil, err + } + + // Set the account ID for all users + for _, user := range users { + user.AppMetadata.WTAccountID = accountID + } + + return users, nil +} + +// GetAllAccounts gets all registered accounts with corresponding user data. +// Since Dex doesn't have account concepts, all users are returned under UnsetAccountID. +func (dm *DexManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) { + if dm.appMetrics != nil { + dm.appMetrics.IDPMetrics().CountGetAllAccounts() + } + + users, err := dm.getAllUsers(ctx) + if err != nil { + return nil, err + } + + indexedUsers := make(map[string][]*UserData) + indexedUsers[UnsetAccountID] = users + + return indexedUsers, nil +} + +// CreateUser creates a new user in Dex's password database. +func (dm *DexManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) { + if dm.appMetrics != nil { + dm.appMetrics.IDPMetrics().CountCreateUser() + } + + client, err := dm.getDexClient(ctx) + if err != nil { + if dm.appMetrics != nil { + dm.appMetrics.IDPMetrics().CountRequestError() + } + return nil, err + } + + // Generate a random password for the new user + password := GeneratePassword(16, 2, 2, 2) + + // Hash the password using bcrypt + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return nil, fmt.Errorf("failed to hash password: %w", err) + } + + // Generate a user ID from email (Dex uses email as the key, but we need a stable ID) + userID := strings.ReplaceAll(email, "@", "-at-") + userID = strings.ReplaceAll(userID, ".", "-") + + req := &api.CreatePasswordReq{ + Password: &api.Password{ + Email: email, + Username: name, + UserId: userID, + Hash: hashedPassword, + }, + } + + resp, err := client.CreatePassword(ctx, req) + if err != nil { + if dm.appMetrics != nil { + dm.appMetrics.IDPMetrics().CountRequestError() + } + return nil, fmt.Errorf("failed to create user in Dex: %w", err) + } + + if resp.AlreadyExists { + return nil, fmt.Errorf("user with email %s already exists", email) + } + + log.WithContext(ctx).Debugf("created user %s in Dex", email) + + return &UserData{ + Email: email, + Name: name, + ID: userID, + AppMetadata: AppMetadata{ + WTAccountID: accountID, + WTInvitedBy: invitedByEmail, + }, + }, nil +} + +// GetUserByEmail searches users with a given email. +// If no users have been found, this function returns an empty list. +func (dm *DexManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) { + if dm.appMetrics != nil { + dm.appMetrics.IDPMetrics().CountGetUserByEmail() + } + + client, err := dm.getDexClient(ctx) + if err != nil { + if dm.appMetrics != nil { + dm.appMetrics.IDPMetrics().CountRequestError() + } + return nil, err + } + + resp, err := client.ListPasswords(ctx, &api.ListPasswordReq{}) + if err != nil { + if dm.appMetrics != nil { + dm.appMetrics.IDPMetrics().CountRequestError() + } + return nil, fmt.Errorf("failed to list passwords from Dex: %w", err) + } + + users := make([]*UserData, 0) + for _, p := range resp.Passwords { + if strings.EqualFold(p.Email, email) { + // Encode the user ID in Dex's composite format to match stored IDs + encodedID := encodeDexUserID(p.UserId, "local") + users = append(users, &UserData{ + Email: p.Email, + Name: p.Username, + ID: encodedID, + }) + } + } + + return users, nil +} + +// InviteUserByID resends an invitation to a user. +// Dex doesn't support invitations, so this returns an error. +func (dm *DexManager) InviteUserByID(_ context.Context, _ string) error { + return fmt.Errorf("method InviteUserByID not implemented for Dex") +} + +// DeleteUser deletes a user from Dex by user ID. +func (dm *DexManager) DeleteUser(ctx context.Context, userID string) error { + if dm.appMetrics != nil { + dm.appMetrics.IDPMetrics().CountDeleteUser() + } + + client, err := dm.getDexClient(ctx) + if err != nil { + if dm.appMetrics != nil { + dm.appMetrics.IDPMetrics().CountRequestError() + } + return err + } + + // First, find the user's email by ID + resp, err := client.ListPasswords(ctx, &api.ListPasswordReq{}) + if err != nil { + if dm.appMetrics != nil { + dm.appMetrics.IDPMetrics().CountRequestError() + } + return fmt.Errorf("failed to list passwords from Dex: %w", err) + } + + // Try to parse the composite user ID from Dex JWT token + actualUserID := parseDexUserID(userID) + + var email string + for _, p := range resp.Passwords { + if p.UserId == userID || p.UserId == actualUserID { + email = p.Email + break + } + } + + if email == "" { + return fmt.Errorf("user with ID %s not found", userID) + } + + // Delete the user by email + deleteResp, err := client.DeletePassword(ctx, &api.DeletePasswordReq{ + Email: email, + }) + if err != nil { + if dm.appMetrics != nil { + dm.appMetrics.IDPMetrics().CountRequestError() + } + return fmt.Errorf("failed to delete user from Dex: %w", err) + } + + if deleteResp.NotFound { + return fmt.Errorf("user with email %s not found", email) + } + + log.WithContext(ctx).Debugf("deleted user %s from Dex", email) + + return nil +} + +// getAllUsers retrieves all users from Dex's password database. +func (dm *DexManager) getAllUsers(ctx context.Context) ([]*UserData, error) { + client, err := dm.getDexClient(ctx) + if err != nil { + if dm.appMetrics != nil { + dm.appMetrics.IDPMetrics().CountRequestError() + } + return nil, err + } + + resp, err := client.ListPasswords(ctx, &api.ListPasswordReq{}) + if err != nil { + if dm.appMetrics != nil { + dm.appMetrics.IDPMetrics().CountRequestError() + } + return nil, fmt.Errorf("failed to list passwords from Dex: %w", err) + } + + users := make([]*UserData, 0, len(resp.Passwords)) + for _, p := range resp.Passwords { + // Encode the user ID in Dex's composite format (base64-encoded protobuf) + // to match how NetBird stores user IDs from Dex JWT tokens. + // The connector ID "local" is used for Dex's password database. + encodedID := encodeDexUserID(p.UserId, "local") + users = append(users, &UserData{ + Email: p.Email, + Name: p.Username, + ID: encodedID, + }) + } + + return users, nil +} diff --git a/management/server/idp/dex_test.go b/management/server/idp/dex_test.go new file mode 100644 index 000000000..b1991bd9f --- /dev/null +++ b/management/server/idp/dex_test.go @@ -0,0 +1,137 @@ +package idp + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/telemetry" +) + +func TestNewDexManager(t *testing.T) { + type test struct { + name string + inputConfig DexClientConfig + assertErrFunc require.ErrorAssertionFunc + assertErrFuncMessage string + } + + defaultTestConfig := DexClientConfig{ + GRPCAddr: "localhost:5557", + Issuer: "https://dex.example.com/dex", + } + + testCase1 := test{ + name: "Good Configuration", + inputConfig: defaultTestConfig, + assertErrFunc: require.NoError, + assertErrFuncMessage: "shouldn't return error", + } + + testCase2Config := defaultTestConfig + testCase2Config.GRPCAddr = "" + + testCase2 := test{ + name: "Missing GRPCAddr Configuration", + inputConfig: testCase2Config, + assertErrFunc: require.Error, + assertErrFuncMessage: "should return error when GRPCAddr is empty", + } + + // Test with empty issuer - should still work since issuer is optional for the manager + testCase3Config := defaultTestConfig + testCase3Config.Issuer = "" + + testCase3 := test{ + name: "Missing Issuer Configuration - OK", + inputConfig: testCase3Config, + assertErrFunc: require.NoError, + assertErrFuncMessage: "shouldn't return error when issuer is empty", + } + + for _, testCase := range []test{testCase1, testCase2, testCase3} { + t.Run(testCase.name, func(t *testing.T) { + manager, err := NewDexManager(testCase.inputConfig, &telemetry.MockAppMetrics{}) + testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage) + + if err == nil { + require.NotNil(t, manager, "manager should not be nil") + require.Equal(t, testCase.inputConfig.GRPCAddr, manager.grpcAddr, "grpcAddr should match") + } + }) + } +} + +func TestDexManagerUpdateUserAppMetadata(t *testing.T) { + config := DexClientConfig{ + GRPCAddr: "localhost:5557", + Issuer: "https://dex.example.com/dex", + } + + manager, err := NewDexManager(config, &telemetry.MockAppMetrics{}) + require.NoError(t, err, "should create manager without error") + + // UpdateUserAppMetadata should be a no-op for Dex + err = manager.UpdateUserAppMetadata(context.Background(), "test-user-id", AppMetadata{ + WTAccountID: "test-account", + }) + require.NoError(t, err, "UpdateUserAppMetadata should not return error") +} + +func TestDexManagerInviteUserByID(t *testing.T) { + config := DexClientConfig{ + GRPCAddr: "localhost:5557", + Issuer: "https://dex.example.com/dex", + } + + manager, err := NewDexManager(config, &telemetry.MockAppMetrics{}) + require.NoError(t, err, "should create manager without error") + + // InviteUserByID should return an error for Dex + err = manager.InviteUserByID(context.Background(), "test-user-id") + require.Error(t, err, "InviteUserByID should return error") + require.Contains(t, err.Error(), "not implemented", "error should mention not implemented") +} + +func TestParseDexUserID(t *testing.T) { + tests := []struct { + name string + compositeID string + expectedID string + }{ + { + name: "Parse base64-encoded protobuf composite ID", + // This is a real Dex composite ID: contains user ID "cf5db180-d360-484d-9b78-c5db92146420" and connector "local" + compositeID: "CiRjZjVkYjE4MC1kMzYwLTQ4NGQtOWI3OC1jNWRiOTIxNDY0MjASBWxvY2Fs", + expectedID: "cf5db180-d360-484d-9b78-c5db92146420", + }, + { + name: "Return plain ID unchanged", + compositeID: "simple-user-id", + expectedID: "simple-user-id", + }, + { + name: "Return UUID unchanged", + compositeID: "cf5db180-d360-484d-9b78-c5db92146420", + expectedID: "cf5db180-d360-484d-9b78-c5db92146420", + }, + { + name: "Handle empty string", + compositeID: "", + expectedID: "", + }, + { + name: "Handle invalid base64", + compositeID: "not-valid-base64!!!", + expectedID: "not-valid-base64!!!", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseDexUserID(tt.compositeID) + require.Equal(t, tt.expectedID, result, "parsed user ID should match expected") + }) + } +} diff --git a/management/server/idp/embedded.go b/management/server/idp/embedded.go new file mode 100644 index 000000000..963b5ae3d --- /dev/null +++ b/management/server/idp/embedded.go @@ -0,0 +1,511 @@ +package idp + +import ( + "context" + "errors" + "fmt" + "net/http" + + "github.com/dexidp/dex/storage" + "github.com/google/uuid" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/idp/dex" + "github.com/netbirdio/netbird/management/server/telemetry" +) + +const ( + staticClientDashboard = "netbird-dashboard" + staticClientCLI = "netbird-cli" + defaultCLIRedirectURL1 = "http://localhost:53000/" + defaultCLIRedirectURL2 = "http://localhost:54000/" + defaultScopes = "openid profile email offline_access" + defaultUserIDClaim = "sub" +) + +// EmbeddedIdPConfig contains configuration for the embedded Dex OIDC identity provider +type EmbeddedIdPConfig struct { + // Enabled indicates whether the embedded IDP is enabled + Enabled bool + // Issuer is the OIDC issuer URL (e.g., "http://localhost:3002/oauth2") + Issuer string + // Storage configuration for the IdP database + Storage EmbeddedStorageConfig + // DashboardRedirectURIs are the OAuth2 redirect URIs for the dashboard client + DashboardRedirectURIs []string + // DashboardRedirectURIs are the OAuth2 redirect URIs for the dashboard client + CLIRedirectURIs []string + // Owner is the initial owner/admin user (optional, can be nil) + Owner *OwnerConfig + // SignKeyRefreshEnabled enables automatic key rotation for signing keys + SignKeyRefreshEnabled bool +} + +// EmbeddedStorageConfig holds storage configuration for the embedded IdP. +type EmbeddedStorageConfig struct { + // Type is the storage type (currently only "sqlite3" is supported) + Type string + // Config contains type-specific configuration + Config EmbeddedStorageTypeConfig +} + +// EmbeddedStorageTypeConfig contains type-specific storage configuration. +type EmbeddedStorageTypeConfig struct { + // File is the path to the SQLite database file (for sqlite3 type) + File string +} + +// OwnerConfig represents the initial owner/admin user for the embedded IdP. +type OwnerConfig struct { + // Email is the user's email address (required) + Email string + // Hash is the bcrypt hash of the user's password (required) + Hash string + // Username is the display name for the user (optional, defaults to email) + Username string +} + +// ToYAMLConfig converts EmbeddedIdPConfig to dex.YAMLConfig. +func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) { + if c.Issuer == "" { + return nil, fmt.Errorf("issuer is required") + } + if c.Storage.Type == "" { + c.Storage.Type = "sqlite3" + } + if c.Storage.Type == "sqlite3" && c.Storage.Config.File == "" { + return nil, fmt.Errorf("storage file is required for sqlite3") + } + + // Build CLI redirect URIs including the device callback (both relative and absolute) + cliRedirectURIs := c.CLIRedirectURIs + cliRedirectURIs = append(cliRedirectURIs, "/device/callback") + cliRedirectURIs = append(cliRedirectURIs, c.Issuer+"/device/callback") + + cfg := &dex.YAMLConfig{ + Issuer: c.Issuer, + Storage: dex.Storage{ + Type: c.Storage.Type, + Config: map[string]interface{}{ + "file": c.Storage.Config.File, + }, + }, + Web: dex.Web{ + AllowedOrigins: []string{"*"}, + AllowedHeaders: []string{"Authorization", "Content-Type"}, + }, + OAuth2: dex.OAuth2{ + SkipApprovalScreen: true, + }, + Frontend: dex.Frontend{ + Issuer: "NetBird", + Theme: "light", + }, + EnablePasswordDB: true, + StaticClients: []storage.Client{ + { + ID: staticClientDashboard, + Name: "NetBird Dashboard", + Public: true, + RedirectURIs: c.DashboardRedirectURIs, + }, + { + ID: staticClientCLI, + Name: "NetBird CLI", + Public: true, + RedirectURIs: cliRedirectURIs, + }, + }, + } + + // Add owner user if provided + if c.Owner != nil && c.Owner.Email != "" && c.Owner.Hash != "" { + username := c.Owner.Username + if username == "" { + username = c.Owner.Email + } + cfg.StaticPasswords = []dex.Password{ + { + Email: c.Owner.Email, + Hash: []byte(c.Owner.Hash), + Username: username, + UserID: uuid.New().String(), + }, + } + } + + return cfg, nil +} + +// Compile-time check that EmbeddedIdPManager implements Manager interface +var _ Manager = (*EmbeddedIdPManager)(nil) + +// Compile-time check that EmbeddedIdPManager implements OAuthConfigProvider interface +var _ OAuthConfigProvider = (*EmbeddedIdPManager)(nil) + +// OAuthConfigProvider defines the interface for OAuth configuration needed by auth flows. +type OAuthConfigProvider interface { + GetIssuer() string + GetKeysLocation() string + GetClientIDs() []string + GetUserIDClaim() string + GetTokenEndpoint() string + GetDeviceAuthEndpoint() string + GetAuthorizationEndpoint() string + GetDefaultScopes() string + GetCLIClientID() string + GetCLIRedirectURLs() []string +} + +// EmbeddedIdPManager implements the Manager interface using the embedded Dex IdP. +type EmbeddedIdPManager struct { + provider *dex.Provider + appMetrics telemetry.AppMetrics + config EmbeddedIdPConfig +} + +// NewEmbeddedIdPManager creates a new instance of EmbeddedIdPManager from a configuration. +// It instantiates the underlying Dex provider internally. +// Note: Storage defaults are applied in config loading (applyEmbeddedIdPConfig) based on Datadir. +func NewEmbeddedIdPManager(ctx context.Context, config *EmbeddedIdPConfig, appMetrics telemetry.AppMetrics) (*EmbeddedIdPManager, error) { + if config == nil { + return nil, fmt.Errorf("embedded IdP config is required") + } + + // Apply defaults for CLI redirect URIs + if len(config.CLIRedirectURIs) == 0 { + config.CLIRedirectURIs = []string{defaultCLIRedirectURL1, defaultCLIRedirectURL2} + } + + // there are some properties create when creating YAML config (e.g., auth clients) + yamlConfig, err := config.ToYAMLConfig() + if err != nil { + return nil, err + } + + provider, err := dex.NewProviderFromYAML(ctx, yamlConfig) + if err != nil { + return nil, fmt.Errorf("failed to create embedded IdP provider: %w", err) + } + + log.WithContext(ctx).Infof("embedded Dex IDP initialized with issuer: %s", yamlConfig.Issuer) + + return &EmbeddedIdPManager{ + provider: provider, + appMetrics: appMetrics, + config: *config, + }, nil +} + +// Handler returns the HTTP handler for serving OIDC requests. +func (m *EmbeddedIdPManager) Handler() http.Handler { + return m.provider.Handler() +} + +// Stop gracefully shuts down the embedded IdP provider. +func (m *EmbeddedIdPManager) Stop(ctx context.Context) error { + return m.provider.Stop(ctx) +} + +// UpdateUserAppMetadata updates user app metadata based on userID and metadata map. +func (m *EmbeddedIdPManager) UpdateUserAppMetadata(ctx context.Context, userID string, appMetadata AppMetadata) error { + // TODO: implement + return nil +} + +// GetUserDataByID requests user data from the embedded IdP via user ID. +func (m *EmbeddedIdPManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) { + user, err := m.provider.GetUserByID(ctx, userID) + if err != nil { + if m.appMetrics != nil { + m.appMetrics.IDPMetrics().CountRequestError() + } + return nil, fmt.Errorf("failed to get user by ID: %w", err) + } + + return &UserData{ + Email: user.Email, + Name: user.Username, + ID: user.UserID, + AppMetadata: appMetadata, + }, nil +} + +// GetAccount returns all the users for a given account. +// Note: Embedded dex doesn't store account metadata, so this returns all users. +func (m *EmbeddedIdPManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) { + users, err := m.provider.ListUsers(ctx) + if err != nil { + if m.appMetrics != nil { + m.appMetrics.IDPMetrics().CountRequestError() + } + return nil, fmt.Errorf("failed to list users: %w", err) + } + + result := make([]*UserData, 0, len(users)) + for _, user := range users { + result = append(result, &UserData{ + Email: user.Email, + Name: user.Username, + ID: user.UserID, + AppMetadata: AppMetadata{ + WTAccountID: accountID, + }, + }) + } + + return result, nil +} + +// GetAllAccounts gets all registered accounts with corresponding user data. +// Note: Embedded dex doesn't store account metadata, so all users are indexed under UnsetAccountID. +func (m *EmbeddedIdPManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) { + if m.appMetrics != nil { + m.appMetrics.IDPMetrics().CountGetAllAccounts() + } + + users, err := m.provider.ListUsers(ctx) + if err != nil { + if m.appMetrics != nil { + m.appMetrics.IDPMetrics().CountRequestError() + } + return nil, fmt.Errorf("failed to list users: %w", err) + } + + indexedUsers := make(map[string][]*UserData) + for _, user := range users { + indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], &UserData{ + Email: user.Email, + Name: user.Username, + ID: user.UserID, + }) + } + + return indexedUsers, nil +} + +// CreateUser creates a new user in the embedded IdP. +func (m *EmbeddedIdPManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) { + if m.appMetrics != nil { + m.appMetrics.IDPMetrics().CountCreateUser() + } + + // Check if user already exists + _, err := m.provider.GetUser(ctx, email) + if err == nil { + return nil, fmt.Errorf("user with email %s already exists", email) + } + if !errors.Is(err, storage.ErrNotFound) { + if m.appMetrics != nil { + m.appMetrics.IDPMetrics().CountRequestError() + } + return nil, fmt.Errorf("failed to check existing user: %w", err) + } + + // Generate a random password for the new user + password := GeneratePassword(16, 2, 2, 2) + + // Create the user via provider (handles hashing and ID generation) + // The provider returns an encoded user ID in Dex's format (base64 protobuf with connector ID) + userID, err := m.provider.CreateUser(ctx, email, name, password) + if err != nil { + if m.appMetrics != nil { + m.appMetrics.IDPMetrics().CountRequestError() + } + return nil, fmt.Errorf("failed to create user in embedded IdP: %w", err) + } + + log.WithContext(ctx).Debugf("created user %s in embedded IdP", email) + + return &UserData{ + Email: email, + Name: name, + ID: userID, + Password: password, + AppMetadata: AppMetadata{ + WTAccountID: accountID, + WTInvitedBy: invitedByEmail, + }, + }, nil +} + +// GetUserByEmail searches users with a given email. +func (m *EmbeddedIdPManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) { + user, err := m.provider.GetUser(ctx, email) + if err != nil { + if errors.Is(err, storage.ErrNotFound) { + return nil, nil // Return empty slice for not found + } + if m.appMetrics != nil { + m.appMetrics.IDPMetrics().CountRequestError() + } + return nil, fmt.Errorf("failed to get user by email: %w", err) + } + + return []*UserData{ + { + Email: user.Email, + Name: user.Username, + ID: user.UserID, + }, + }, nil +} + +// CreateUserWithPassword creates a new user in the embedded IdP with a provided password. +// Unlike CreateUser which auto-generates a password, this method uses the provided password. +// This is useful for instance setup where the user provides their own password. +func (m *EmbeddedIdPManager) CreateUserWithPassword(ctx context.Context, email, password, name string) (*UserData, error) { + if m.appMetrics != nil { + m.appMetrics.IDPMetrics().CountCreateUser() + } + + // Check if user already exists + _, err := m.provider.GetUser(ctx, email) + if err == nil { + return nil, fmt.Errorf("user with email %s already exists", email) + } + if !errors.Is(err, storage.ErrNotFound) { + if m.appMetrics != nil { + m.appMetrics.IDPMetrics().CountRequestError() + } + return nil, fmt.Errorf("failed to check existing user: %w", err) + } + + // Create the user via provider with the provided password + userID, err := m.provider.CreateUser(ctx, email, name, password) + if err != nil { + if m.appMetrics != nil { + m.appMetrics.IDPMetrics().CountRequestError() + } + return nil, fmt.Errorf("failed to create user in embedded IdP: %w", err) + } + + log.WithContext(ctx).Debugf("created user %s in embedded IdP with provided password", email) + + return &UserData{ + Email: email, + Name: name, + ID: userID, + }, nil +} + +// InviteUserByID resends an invitation to a user. +func (m *EmbeddedIdPManager) InviteUserByID(ctx context.Context, userID string) error { + // TODO: implement + return fmt.Errorf("not implemented") +} + +// DeleteUser deletes a user from the embedded IdP by user ID. +func (m *EmbeddedIdPManager) DeleteUser(ctx context.Context, userID string) error { + if m.appMetrics != nil { + m.appMetrics.IDPMetrics().CountDeleteUser() + } + + // Get user by ID to retrieve email (provider.DeleteUser requires email) + user, err := m.provider.GetUserByID(ctx, userID) + if err != nil { + if m.appMetrics != nil { + m.appMetrics.IDPMetrics().CountRequestError() + } + return fmt.Errorf("failed to get user for deletion: %w", err) + } + + err = m.provider.DeleteUser(ctx, user.Email) + if err != nil { + if m.appMetrics != nil { + m.appMetrics.IDPMetrics().CountRequestError() + } + return fmt.Errorf("failed to delete user from embedded IdP: %w", err) + } + + log.WithContext(ctx).Debugf("deleted user %s from embedded IdP", user.Email) + + return nil +} + +// CreateConnector creates a new identity provider connector in Dex. +// Returns the created connector config with the redirect URL populated. +func (m *EmbeddedIdPManager) CreateConnector(ctx context.Context, cfg *dex.ConnectorConfig) (*dex.ConnectorConfig, error) { + return m.provider.CreateConnector(ctx, cfg) +} + +// GetConnector retrieves an identity provider connector by ID. +func (m *EmbeddedIdPManager) GetConnector(ctx context.Context, id string) (*dex.ConnectorConfig, error) { + return m.provider.GetConnector(ctx, id) +} + +// ListConnectors returns all identity provider connectors. +func (m *EmbeddedIdPManager) ListConnectors(ctx context.Context) ([]*dex.ConnectorConfig, error) { + return m.provider.ListConnectors(ctx) +} + +// UpdateConnector updates an existing identity provider connector. +func (m *EmbeddedIdPManager) UpdateConnector(ctx context.Context, cfg *dex.ConnectorConfig) error { + // Preserve existing secret if not provided in update + if cfg.ClientSecret == "" { + existing, err := m.provider.GetConnector(ctx, cfg.ID) + if err != nil { + return fmt.Errorf("failed to get existing connector: %w", err) + } + cfg.ClientSecret = existing.ClientSecret + } + return m.provider.UpdateConnector(ctx, cfg) +} + +// DeleteConnector removes an identity provider connector. +func (m *EmbeddedIdPManager) DeleteConnector(ctx context.Context, id string) error { + return m.provider.DeleteConnector(ctx, id) +} + +// GetIssuer returns the OIDC issuer URL. +func (m *EmbeddedIdPManager) GetIssuer() string { + return m.provider.GetIssuer() +} + +// GetTokenEndpoint returns the OAuth2 token endpoint URL. +func (m *EmbeddedIdPManager) GetTokenEndpoint() string { + return m.provider.GetTokenEndpoint() +} + +// GetDeviceAuthEndpoint returns the OAuth2 device authorization endpoint URL. +func (m *EmbeddedIdPManager) GetDeviceAuthEndpoint() string { + return m.provider.GetDeviceAuthEndpoint() +} + +// GetAuthorizationEndpoint returns the OAuth2 authorization endpoint URL. +func (m *EmbeddedIdPManager) GetAuthorizationEndpoint() string { + return m.provider.GetAuthorizationEndpoint() +} + +// GetDefaultScopes returns the default OAuth2 scopes for authentication. +func (m *EmbeddedIdPManager) GetDefaultScopes() string { + return defaultScopes +} + +// GetCLIClientID returns the client ID for CLI authentication. +func (m *EmbeddedIdPManager) GetCLIClientID() string { + return staticClientCLI +} + +// GetCLIRedirectURLs returns the redirect URLs configured for the CLI client. +func (m *EmbeddedIdPManager) GetCLIRedirectURLs() []string { + if len(m.config.CLIRedirectURIs) == 0 { + return []string{defaultCLIRedirectURL1, defaultCLIRedirectURL2} + } + return m.config.CLIRedirectURIs +} + +// GetKeysLocation returns the JWKS endpoint URL for token validation. +func (m *EmbeddedIdPManager) GetKeysLocation() string { + return m.provider.GetKeysLocation() +} + +// GetClientIDs returns the OAuth2 client IDs configured for this provider. +func (m *EmbeddedIdPManager) GetClientIDs() []string { + return []string{staticClientDashboard, staticClientCLI} +} + +// GetUserIDClaim returns the JWT claim name used for user identification. +func (m *EmbeddedIdPManager) GetUserIDClaim() string { + return defaultUserIDClaim +} diff --git a/management/server/idp/embedded_test.go b/management/server/idp/embedded_test.go new file mode 100644 index 000000000..cfd9c2b54 --- /dev/null +++ b/management/server/idp/embedded_test.go @@ -0,0 +1,249 @@ +package idp + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/idp/dex" +) + +func TestEmbeddedIdPManager_CreateUser_EndToEnd(t *testing.T) { + ctx := context.Background() + + // Create a temporary directory for the test + tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + // Create the embedded IDP config + config := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: filepath.Join(tmpDir, "dex.db"), + }, + }, + } + + // Create the embedded IDP manager + manager, err := NewEmbeddedIdPManager(ctx, config, nil) + require.NoError(t, err) + defer func() { _ = manager.Stop(ctx) }() + + // Test data + email := "newuser@example.com" + name := "New User" + accountID := "test-account-id" + invitedByEmail := "admin@example.com" + + // Create the user + userData, err := manager.CreateUser(ctx, email, name, accountID, invitedByEmail) + require.NoError(t, err) + require.NotNil(t, userData) + + t.Logf("Created user: ID=%s, Email=%s, Name=%s, Password=%s", + userData.ID, userData.Email, userData.Name, userData.Password) + + // Verify user data + assert.Equal(t, email, userData.Email) + assert.Equal(t, name, userData.Name) + assert.NotEmpty(t, userData.ID) + assert.NotEmpty(t, userData.Password) + assert.Equal(t, accountID, userData.AppMetadata.WTAccountID) + assert.Equal(t, invitedByEmail, userData.AppMetadata.WTInvitedBy) + + // Verify the user ID is in Dex's encoded format (base64 protobuf) + rawUserID, connectorID, err := dex.DecodeDexUserID(userData.ID) + require.NoError(t, err) + assert.NotEmpty(t, rawUserID) + assert.Equal(t, "local", connectorID) + + t.Logf("Decoded user ID: rawUserID=%s, connectorID=%s", rawUserID, connectorID) + + // Verify we can look up the user by the encoded ID + lookedUpUser, err := manager.GetUserDataByID(ctx, userData.ID, AppMetadata{WTAccountID: accountID}) + require.NoError(t, err) + assert.Equal(t, email, lookedUpUser.Email) + + // Verify we can look up by email + users, err := manager.GetUserByEmail(ctx, email) + require.NoError(t, err) + require.Len(t, users, 1) + assert.Equal(t, email, users[0].Email) + + // Verify creating duplicate user fails + _, err = manager.CreateUser(ctx, email, name, accountID, invitedByEmail) + assert.Error(t, err) + assert.Contains(t, err.Error(), "already exists") +} + +func TestEmbeddedIdPManager_GetUserDataByID_WithEncodedID(t *testing.T) { + ctx := context.Background() + + tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + config := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: filepath.Join(tmpDir, "dex.db"), + }, + }, + } + + manager, err := NewEmbeddedIdPManager(ctx, config, nil) + require.NoError(t, err) + defer func() { _ = manager.Stop(ctx) }() + + // Create a user first + userData, err := manager.CreateUser(ctx, "test@example.com", "Test User", "account1", "admin@example.com") + require.NoError(t, err) + + // The returned ID should be encoded + encodedID := userData.ID + + // Lookup should work with the encoded ID + lookedUp, err := manager.GetUserDataByID(ctx, encodedID, AppMetadata{WTAccountID: "account1"}) + require.NoError(t, err) + assert.Equal(t, "test@example.com", lookedUp.Email) + assert.Equal(t, "Test User", lookedUp.Name) +} + +func TestEmbeddedIdPManager_DeleteUser(t *testing.T) { + ctx := context.Background() + + tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + config := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: filepath.Join(tmpDir, "dex.db"), + }, + }, + } + + manager, err := NewEmbeddedIdPManager(ctx, config, nil) + require.NoError(t, err) + defer func() { _ = manager.Stop(ctx) }() + + // Create a user + userData, err := manager.CreateUser(ctx, "delete-me@example.com", "Delete Me", "account1", "admin@example.com") + require.NoError(t, err) + + // Delete the user using the encoded ID + err = manager.DeleteUser(ctx, userData.ID) + require.NoError(t, err) + + // Verify user no longer exists + _, err = manager.GetUserDataByID(ctx, userData.ID, AppMetadata{}) + assert.Error(t, err) +} + +func TestEmbeddedIdPManager_GetAccount(t *testing.T) { + ctx := context.Background() + + tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + config := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: filepath.Join(tmpDir, "dex.db"), + }, + }, + } + + manager, err := NewEmbeddedIdPManager(ctx, config, nil) + require.NoError(t, err) + defer func() { _ = manager.Stop(ctx) }() + + // Create multiple users + _, err = manager.CreateUser(ctx, "user1@example.com", "User 1", "account1", "admin@example.com") + require.NoError(t, err) + + _, err = manager.CreateUser(ctx, "user2@example.com", "User 2", "account1", "admin@example.com") + require.NoError(t, err) + + // Get all users for the account + users, err := manager.GetAccount(ctx, "account1") + require.NoError(t, err) + assert.Len(t, users, 2) + + emails := make([]string, len(users)) + for i, u := range users { + emails[i] = u.Email + } + assert.Contains(t, emails, "user1@example.com") + assert.Contains(t, emails, "user2@example.com") +} + +func TestEmbeddedIdPManager_UserIDFormat_MatchesJWT(t *testing.T) { + // This test verifies that the user ID returned by CreateUser + // matches the format that Dex uses in JWT tokens (the 'sub' claim) + ctx := context.Background() + + tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + config := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: filepath.Join(tmpDir, "dex.db"), + }, + }, + } + + manager, err := NewEmbeddedIdPManager(ctx, config, nil) + require.NoError(t, err) + defer func() { _ = manager.Stop(ctx) }() + + // Create a user + userData, err := manager.CreateUser(ctx, "jwt-test@example.com", "JWT Test", "account1", "admin@example.com") + require.NoError(t, err) + + // The ID should be in the format: base64(protobuf{user_id, connector_id}) + // Example: CiQ3YWFkOGMwNS0zMjg3LTQ3M2YtYjQyYS0zNjU1MDRiZjI1ZTcSBWxvY2Fs + + // Verify it can be decoded + rawUserID, connectorID, err := dex.DecodeDexUserID(userData.ID) + require.NoError(t, err) + + // Raw user ID should be a UUID + assert.Regexp(t, `^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`, rawUserID) + + // Connector ID should be "local" for password-based auth + assert.Equal(t, "local", connectorID) + + // Re-encoding should produce the same result + reEncoded := dex.EncodeDexUserID(rawUserID, connectorID) + assert.Equal(t, userData.ID, reEncoded) + + t.Logf("User ID format verified:") + t.Logf(" Encoded ID: %s", userData.ID) + t.Logf(" Raw UUID: %s", rawUserID) + t.Logf(" Connector: %s", connectorID) +} diff --git a/management/server/idp/idp.go b/management/server/idp/idp.go index f06e57196..28e3d81f9 100644 --- a/management/server/idp/idp.go +++ b/management/server/idp/idp.go @@ -72,6 +72,7 @@ type UserData struct { Name string `json:"name"` ID string `json:"user_id"` AppMetadata AppMetadata `json:"app_metadata"` + Password string `json:"-"` // Plain password, only set on user creation, excluded from JSON } func (u *UserData) MarshalBinary() (data []byte, err error) { @@ -173,40 +174,40 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr return NewZitadelManager(*zitadelClientConfig, appMetrics) case "authentik": - authentikConfig := AuthentikClientConfig{ + return NewAuthentikManager(AuthentikClientConfig{ Issuer: config.ClientConfig.Issuer, ClientID: config.ClientConfig.ClientID, TokenEndpoint: config.ClientConfig.TokenEndpoint, GrantType: config.ClientConfig.GrantType, Username: config.ExtraConfig["Username"], Password: config.ExtraConfig["Password"], - } - return NewAuthentikManager(authentikConfig, appMetrics) + }, appMetrics) case "okta": - oktaClientConfig := OktaClientConfig{ + return NewOktaManager(OktaClientConfig{ Issuer: config.ClientConfig.Issuer, TokenEndpoint: config.ClientConfig.TokenEndpoint, GrantType: config.ClientConfig.GrantType, APIToken: config.ExtraConfig["ApiToken"], - } - return NewOktaManager(oktaClientConfig, appMetrics) + }, appMetrics) case "google": - googleClientConfig := GoogleWorkspaceClientConfig{ + return NewGoogleWorkspaceManager(ctx, GoogleWorkspaceClientConfig{ ServiceAccountKey: config.ExtraConfig["ServiceAccountKey"], CustomerID: config.ExtraConfig["CustomerId"], - } - return NewGoogleWorkspaceManager(ctx, googleClientConfig, appMetrics) + }, appMetrics) case "jumpcloud": - jumpcloudConfig := JumpCloudClientConfig{ + return NewJumpCloudManager(JumpCloudClientConfig{ APIToken: config.ExtraConfig["ApiToken"], - } - return NewJumpCloudManager(jumpcloudConfig, appMetrics) + }, appMetrics) case "pocketid": - pocketidConfig := PocketIdClientConfig{ + return NewPocketIdManager(PocketIdClientConfig{ APIToken: config.ExtraConfig["ApiToken"], ManagementEndpoint: config.ExtraConfig["ManagementEndpoint"], - } - return NewPocketIdManager(pocketidConfig, appMetrics) + }, appMetrics) + case "dex": + return NewDexManager(DexClientConfig{ + GRPCAddr: config.ExtraConfig["GRPCAddr"], + Issuer: config.ClientConfig.Issuer, + }, appMetrics) default: return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType) } diff --git a/management/server/instance/manager.go b/management/server/instance/manager.go new file mode 100644 index 000000000..6f50e3ff7 --- /dev/null +++ b/management/server/instance/manager.go @@ -0,0 +1,136 @@ +package instance + +import ( + "context" + "errors" + "fmt" + "net/mail" + "sync" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/status" +) + +// Manager handles instance-level operations like initial setup. +type Manager interface { + // IsSetupRequired checks if instance setup is required. + // Returns true if embedded IDP is enabled and no accounts exist. + IsSetupRequired(ctx context.Context) (bool, error) + + // CreateOwnerUser creates the initial owner user in the embedded IDP. + // This should only be called when IsSetupRequired returns true. + CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) +} + +// DefaultManager is the default implementation of Manager. +type DefaultManager struct { + store store.Store + embeddedIdpManager *idp.EmbeddedIdPManager + + setupRequired bool + setupMu sync.RWMutex +} + +// NewManager creates a new instance manager. +// If idpManager is not an EmbeddedIdPManager, setup-related operations will return appropriate defaults. +func NewManager(ctx context.Context, store store.Store, idpManager idp.Manager) (Manager, error) { + embeddedIdp, _ := idpManager.(*idp.EmbeddedIdPManager) + + m := &DefaultManager{ + store: store, + embeddedIdpManager: embeddedIdp, + setupRequired: false, + } + + if embeddedIdp != nil { + err := m.loadSetupRequired(ctx) + if err != nil { + return nil, err + } + } + + return m, nil +} + +func (m *DefaultManager) loadSetupRequired(ctx context.Context) error { + users, err := m.embeddedIdpManager.GetAllAccounts(ctx) + if err != nil { + return err + } + + m.setupMu.Lock() + m.setupRequired = len(users) == 0 + m.setupMu.Unlock() + + return nil +} + +// IsSetupRequired checks if instance setup is required. +// Setup is required when: +// 1. Embedded IDP is enabled +// 2. No accounts exist in the store +func (m *DefaultManager) IsSetupRequired(_ context.Context) (bool, error) { + if m.embeddedIdpManager == nil { + return false, nil + } + + m.setupMu.RLock() + defer m.setupMu.RUnlock() + + return m.setupRequired, nil +} + +// CreateOwnerUser creates the initial owner user in the embedded IDP. +func (m *DefaultManager) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) { + + if err := m.validateSetupInfo(email, password, name); err != nil { + return nil, err + } + + if m.embeddedIdpManager == nil { + return nil, errors.New("embedded IDP is not enabled") + } + + m.setupMu.RLock() + setupRequired := m.setupRequired + m.setupMu.RUnlock() + + if !setupRequired { + return nil, status.Errorf(status.PreconditionFailed, "setup already completed") + } + + userData, err := m.embeddedIdpManager.CreateUserWithPassword(ctx, email, password, name) + if err != nil { + return nil, fmt.Errorf("failed to create user in embedded IdP: %w", err) + } + + m.setupMu.Lock() + m.setupRequired = false + m.setupMu.Unlock() + + log.WithContext(ctx).Infof("created owner user %s in embedded IdP", email) + + return userData, nil +} + +func (m *DefaultManager) validateSetupInfo(email, password, name string) error { + if email == "" { + return status.Errorf(status.InvalidArgument, "email is required") + } + if _, err := mail.ParseAddress(email); err != nil { + return status.Errorf(status.InvalidArgument, "invalid email format") + } + if name == "" { + return status.Errorf(status.InvalidArgument, "name is required") + } + if password == "" { + return status.Errorf(status.InvalidArgument, "password is required") + } + if len(password) < 8 { + return status.Errorf(status.InvalidArgument, "password must be at least 8 characters") + } + return nil +} diff --git a/management/server/instance/manager_test.go b/management/server/instance/manager_test.go new file mode 100644 index 000000000..35d0ff53c --- /dev/null +++ b/management/server/instance/manager_test.go @@ -0,0 +1,268 @@ +package instance + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/idp" +) + +// mockStore implements a minimal store.Store for testing +type mockStore struct { + accountsCount int64 + err error +} + +func (m *mockStore) GetAccountsCounter(ctx context.Context) (int64, error) { + if m.err != nil { + return 0, m.err + } + return m.accountsCount, nil +} + +// mockEmbeddedIdPManager wraps the real EmbeddedIdPManager for testing +type mockEmbeddedIdPManager struct { + createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error) +} + +func (m *mockEmbeddedIdPManager) CreateUserWithPassword(ctx context.Context, email, password, name string) (*idp.UserData, error) { + if m.createUserFunc != nil { + return m.createUserFunc(ctx, email, password, name) + } + return &idp.UserData{ + ID: "test-user-id", + Email: email, + Name: name, + }, nil +} + +// testManager is a test implementation that accepts our mock types +type testManager struct { + store *mockStore + embeddedIdpManager *mockEmbeddedIdPManager +} + +func (m *testManager) IsSetupRequired(ctx context.Context) (bool, error) { + if m.embeddedIdpManager == nil { + return false, nil + } + + count, err := m.store.GetAccountsCounter(ctx) + if err != nil { + return false, err + } + + return count == 0, nil +} + +func (m *testManager) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) { + if m.embeddedIdpManager == nil { + return nil, errors.New("embedded IDP is not enabled") + } + + return m.embeddedIdpManager.CreateUserWithPassword(ctx, email, password, name) +} + +func TestIsSetupRequired_EmbeddedIdPDisabled(t *testing.T) { + manager := &testManager{ + store: &mockStore{accountsCount: 0}, + embeddedIdpManager: nil, // No embedded IDP + } + + required, err := manager.IsSetupRequired(context.Background()) + require.NoError(t, err) + assert.False(t, required, "setup should not be required when embedded IDP is disabled") +} + +func TestIsSetupRequired_NoAccounts(t *testing.T) { + manager := &testManager{ + store: &mockStore{accountsCount: 0}, + embeddedIdpManager: &mockEmbeddedIdPManager{}, + } + + required, err := manager.IsSetupRequired(context.Background()) + require.NoError(t, err) + assert.True(t, required, "setup should be required when no accounts exist") +} + +func TestIsSetupRequired_AccountsExist(t *testing.T) { + manager := &testManager{ + store: &mockStore{accountsCount: 1}, + embeddedIdpManager: &mockEmbeddedIdPManager{}, + } + + required, err := manager.IsSetupRequired(context.Background()) + require.NoError(t, err) + assert.False(t, required, "setup should not be required when accounts exist") +} + +func TestIsSetupRequired_MultipleAccounts(t *testing.T) { + manager := &testManager{ + store: &mockStore{accountsCount: 5}, + embeddedIdpManager: &mockEmbeddedIdPManager{}, + } + + required, err := manager.IsSetupRequired(context.Background()) + require.NoError(t, err) + assert.False(t, required, "setup should not be required when multiple accounts exist") +} + +func TestIsSetupRequired_StoreError(t *testing.T) { + manager := &testManager{ + store: &mockStore{err: errors.New("database error")}, + embeddedIdpManager: &mockEmbeddedIdPManager{}, + } + + _, err := manager.IsSetupRequired(context.Background()) + assert.Error(t, err, "should return error when store fails") +} + +func TestCreateOwnerUser_Success(t *testing.T) { + expectedEmail := "admin@example.com" + expectedName := "Admin User" + expectedPassword := "securepassword123" + + manager := &testManager{ + store: &mockStore{accountsCount: 0}, + embeddedIdpManager: &mockEmbeddedIdPManager{ + createUserFunc: func(ctx context.Context, email, password, name string) (*idp.UserData, error) { + assert.Equal(t, expectedEmail, email) + assert.Equal(t, expectedPassword, password) + assert.Equal(t, expectedName, name) + return &idp.UserData{ + ID: "created-user-id", + Email: email, + Name: name, + }, nil + }, + }, + } + + userData, err := manager.CreateOwnerUser(context.Background(), expectedEmail, expectedPassword, expectedName) + require.NoError(t, err) + assert.Equal(t, "created-user-id", userData.ID) + assert.Equal(t, expectedEmail, userData.Email) + assert.Equal(t, expectedName, userData.Name) +} + +func TestCreateOwnerUser_EmbeddedIdPDisabled(t *testing.T) { + manager := &testManager{ + store: &mockStore{accountsCount: 0}, + embeddedIdpManager: nil, + } + + _, err := manager.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") + assert.Error(t, err, "should return error when embedded IDP is disabled") + assert.Contains(t, err.Error(), "embedded IDP is not enabled") +} + +func TestCreateOwnerUser_IdPError(t *testing.T) { + manager := &testManager{ + store: &mockStore{accountsCount: 0}, + embeddedIdpManager: &mockEmbeddedIdPManager{ + createUserFunc: func(ctx context.Context, email, password, name string) (*idp.UserData, error) { + return nil, errors.New("user already exists") + }, + }, + } + + _, err := manager.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") + assert.Error(t, err, "should return error when IDP fails") +} + +func TestDefaultManager_ValidateSetupRequest(t *testing.T) { + manager := &DefaultManager{ + setupRequired: true, + } + + tests := []struct { + name string + email string + password string + userName string + expectError bool + errorMsg string + }{ + { + name: "valid request", + email: "admin@example.com", + password: "password123", + userName: "Admin User", + expectError: false, + }, + { + name: "empty email", + email: "", + password: "password123", + userName: "Admin User", + expectError: true, + errorMsg: "email is required", + }, + { + name: "invalid email format", + email: "not-an-email", + password: "password123", + userName: "Admin User", + expectError: true, + errorMsg: "invalid email format", + }, + { + name: "empty name", + email: "admin@example.com", + password: "password123", + userName: "", + expectError: true, + errorMsg: "name is required", + }, + { + name: "empty password", + email: "admin@example.com", + password: "", + userName: "Admin User", + expectError: true, + errorMsg: "password is required", + }, + { + name: "password too short", + email: "admin@example.com", + password: "short", + userName: "Admin User", + expectError: true, + errorMsg: "password must be at least 8 characters", + }, + { + name: "password exactly 8 characters", + email: "admin@example.com", + password: "12345678", + userName: "Admin User", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := manager.validateSetupInfo(tt.email, tt.password, tt.userName) + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestDefaultManager_CreateOwnerUser_SetupAlreadyCompleted(t *testing.T) { + manager := &DefaultManager{ + setupRequired: false, + embeddedIdpManager: &idp.EmbeddedIdPManager{}, + } + + _, err := manager.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin") + require.Error(t, err) + assert.Contains(t, err.Error(), "setup already completed") +} diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 42f192c0a..cc302400f 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -381,7 +381,7 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config return nil, nil, "", cleanup, err } - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController, nil) if err != nil { return nil, nil, "", cleanup, err } diff --git a/management/server/management_test.go b/management/server/management_test.go index 648201d4e..ace372509 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -242,6 +242,7 @@ func startServer( nil, server.MockIntegratedValidator{}, networkMapController, + nil, ) if err != nil { t.Fatalf("failed creating management server: %v", err) diff --git a/management/server/metrics/selfhosted.go b/management/server/metrics/selfhosted.go index 4ce57b1da..f7a344fcd 100644 --- a/management/server/metrics/selfhosted.go +++ b/management/server/metrics/selfhosted.go @@ -13,6 +13,7 @@ import ( "time" "github.com/hashicorp/go-version" + "github.com/netbirdio/netbird/idp/dex" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/types" @@ -28,6 +29,7 @@ const ( defaultPushInterval = 12 * time.Hour // requestTimeout http request timeout requestTimeout = 45 * time.Second + EmbeddedType = "embedded" ) type getTokenResponse struct { @@ -206,6 +208,8 @@ func (w *Worker) generateProperties(ctx context.Context) properties { peerActiveVersions []string osUIClients map[string]int rosenpassEnabled int + localUsers int + idpUsers int ) start := time.Now() metricsProperties := make(properties) @@ -266,6 +270,16 @@ func (w *Worker) generateProperties(ctx context.Context) properties { serviceUsers++ } else { users++ + if w.idpManager == EmbeddedType { + _, idpID, err := dex.DecodeDexUserID(user.Id) + if err == nil { + if idpID == "local" { + localUsers++ + } else { + idpUsers++ + } + } + } } pats += len(user.PATs) } @@ -353,6 +367,8 @@ func (w *Worker) generateProperties(ctx context.Context) properties { metricsProperties["idp_manager"] = w.idpManager metricsProperties["store_engine"] = w.dataSource.GetStoreEngine() metricsProperties["rosenpass_enabled"] = rosenpassEnabled + metricsProperties["local_users_count"] = localUsers + metricsProperties["idp_users_count"] = idpUsers for protocol, count := range rulesProtocol { metricsProperties["rules_protocol_"+protocol] = count diff --git a/management/server/metrics/selfhosted_test.go b/management/server/metrics/selfhosted_test.go index db0d90e64..d0ab45cd7 100644 --- a/management/server/metrics/selfhosted_test.go +++ b/management/server/metrics/selfhosted_test.go @@ -5,6 +5,7 @@ import ( "testing" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/idp/dex" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" @@ -25,6 +26,8 @@ func (mockDatasource) GetAllConnectedPeers() map[string]struct{} { // GetAllAccounts returns a list of *server.Account for use in tests with predefined information func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account { + localUserID := dex.EncodeDexUserID("10", "local") + idpUserID := dex.EncodeDexUserID("20", "zitadel") return []*types.Account{ { Id: "1", @@ -98,12 +101,14 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account { }, Users: map[string]*types.User{ "1": { + Id: "1", IsServiceUser: true, PATs: map[string]*types.PersonalAccessToken{ "1": {}, }, }, - "2": { + localUserID: { + Id: localUserID, IsServiceUser: false, PATs: map[string]*types.PersonalAccessToken{ "1": {}, @@ -162,12 +167,14 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account { }, Users: map[string]*types.User{ "1": { + Id: "1", IsServiceUser: true, PATs: map[string]*types.PersonalAccessToken{ "1": {}, }, }, - "2": { + idpUserID: { + Id: idpUserID, IsServiceUser: false, PATs: map[string]*types.PersonalAccessToken{ "1": {}, @@ -214,6 +221,7 @@ func TestGenerateProperties(t *testing.T) { worker := Worker{ dataSource: ds, connManager: ds, + idpManager: EmbeddedType, } properties := worker.generateProperties(context.Background()) @@ -327,4 +335,10 @@ func TestGenerateProperties(t *testing.T) { t.Errorf("expected 1 active_users_last_day, got %d", properties["active_users_last_day"]) } + if properties["local_users_count"] != 1 { + t.Errorf("expected 1 local_users_count, got %d", properties["local_users_count"]) + } + if properties["idp_users_count"] != 1 { + t.Errorf("expected 1 idp_users_count, got %d", properties["idp_users_count"]) + } } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 928098dbe..422829eba 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -2,11 +2,12 @@ package mock_server import ( "context" - "github.com/netbirdio/netbird/shared/auth" "net" "net/netip" "time" + "github.com/netbirdio/netbird/shared/auth" + "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -26,13 +27,13 @@ import ( var _ account.Manager = (*MockAccountManager)(nil) type MockAccountManager struct { - GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*types.Account, error) + GetOrCreateAccountByUserFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.Account, error) GetAccountFunc func(ctx context.Context, accountID string) (*types.Account, error) CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType types.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error) GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) AccountExistsFunc func(ctx context.Context, accountID string) (bool, error) - GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) + GetAccountIDByUserIdFunc func(ctx context.Context, userAuth auth.UserAuth) (string, error) GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) @@ -128,6 +129,12 @@ type MockAccountManager struct { UpdateAccountPeersFunc func(ctx context.Context, accountID string) BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string) RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error + + GetIdentityProviderFunc func(ctx context.Context, accountID, idpID, userID string) (*types.IdentityProvider, error) + GetIdentityProvidersFunc func(ctx context.Context, accountID, userID string) ([]*types.IdentityProvider, error) + CreateIdentityProviderFunc func(ctx context.Context, accountID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error) + UpdateIdentityProviderFunc func(ctx context.Context, accountID, idpID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error) + DeleteIdentityProviderFunc func(ctx context.Context, accountID, idpID, userID string) error } func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error { @@ -236,10 +243,10 @@ func (am *MockAccountManager) DeletePeer(ctx context.Context, accountID, peerID, // GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface func (am *MockAccountManager) GetOrCreateAccountByUser( - ctx context.Context, userId, domain string, + ctx context.Context, userAuth auth.UserAuth, ) (*types.Account, error) { if am.GetOrCreateAccountByUserFunc != nil { - return am.GetOrCreateAccountByUserFunc(ctx, userId, domain) + return am.GetOrCreateAccountByUserFunc(ctx, userAuth) } return nil, status.Errorf( codes.Unimplemented, @@ -275,9 +282,9 @@ func (am *MockAccountManager) AccountExists(ctx context.Context, accountID strin } // GetAccountIDByUserID mock implementation of GetAccountIDByUserID from server.AccountManager interface -func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, domain string) (string, error) { +func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth auth.UserAuth) (string, error) { if am.GetAccountIDByUserIdFunc != nil { - return am.GetAccountIDByUserIdFunc(ctx, userId, domain) + return am.GetAccountIDByUserIdFunc(ctx, userAuth) } return "", status.Errorf( codes.Unimplemented, @@ -988,3 +995,47 @@ func (am *MockAccountManager) RecalculateNetworkMapCache(ctx context.Context, ac } return nil } + +func (am *MockAccountManager) GetUserIDByPeerKey(ctx context.Context, peerKey string) (string, error) { + return "something", nil +} + +// GetIdentityProvider mocks GetIdentityProvider of the AccountManager interface +func (am *MockAccountManager) GetIdentityProvider(ctx context.Context, accountID, idpID, userID string) (*types.IdentityProvider, error) { + if am.GetIdentityProviderFunc != nil { + return am.GetIdentityProviderFunc(ctx, accountID, idpID, userID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetIdentityProvider is not implemented") +} + +// GetIdentityProviders mocks GetIdentityProviders of the AccountManager interface +func (am *MockAccountManager) GetIdentityProviders(ctx context.Context, accountID, userID string) ([]*types.IdentityProvider, error) { + if am.GetIdentityProvidersFunc != nil { + return am.GetIdentityProvidersFunc(ctx, accountID, userID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetIdentityProviders is not implemented") +} + +// CreateIdentityProvider mocks CreateIdentityProvider of the AccountManager interface +func (am *MockAccountManager) CreateIdentityProvider(ctx context.Context, accountID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error) { + if am.CreateIdentityProviderFunc != nil { + return am.CreateIdentityProviderFunc(ctx, accountID, userID, idp) + } + return nil, status.Errorf(codes.Unimplemented, "method CreateIdentityProvider is not implemented") +} + +// UpdateIdentityProvider mocks UpdateIdentityProvider of the AccountManager interface +func (am *MockAccountManager) UpdateIdentityProvider(ctx context.Context, accountID, idpID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error) { + if am.UpdateIdentityProviderFunc != nil { + return am.UpdateIdentityProviderFunc(ctx, accountID, idpID, userID, idp) + } + return nil, status.Errorf(codes.Unimplemented, "method UpdateIdentityProvider is not implemented") +} + +// DeleteIdentityProvider mocks DeleteIdentityProvider of the AccountManager interface +func (am *MockAccountManager) DeleteIdentityProvider(ctx context.Context, accountID, idpID, userID string) error { + if am.DeleteIdentityProviderFunc != nil { + return am.DeleteIdentityProviderFunc(ctx, accountID, idpID, userID) + } + return status.Errorf(codes.Unimplemented, "method DeleteIdentityProvider is not implemented") +} diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index e3dd8b0b8..955c6b0ef 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -865,7 +865,7 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account, userID := testUserID domain := "example.com" - account := newAccountWithId(context.Background(), accountID, userID, domain, false) + account := newAccountWithId(context.Background(), accountID, userID, domain, "", "", false) account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup diff --git a/management/server/peer.go b/management/server/peer.go index 49f5bf2a5..977bd52af 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -91,7 +91,7 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc // fetch all the peers that have access to the user's peers for _, peer := range peers { - aclPeers, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap) + aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap, account.GetActiveGroupUsers()) for _, p := range aclPeers { peersMap[p.ID] = p } @@ -269,6 +269,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user inactivityExpirationChanged = true } + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + return transaction.SavePeer(ctx, accountID, peer) }) if err != nil { @@ -340,6 +344,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer } var peer *nbpeer.Peer + var settings *types.Settings var eventsToStore []func() err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { @@ -348,11 +353,16 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } + settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return err + } + if err = am.validatePeerDelete(ctx, transaction, accountID, peerID); err != nil { return err } - eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}) + eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}, settings) if err != nil { return fmt.Errorf("failed to delete peer: %w", err) } @@ -371,7 +381,11 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer storeEvent() } - if err := am.networkMapController.OnPeersDeleted(ctx, accountID, []string{peerID}); err != nil { + if err = am.integratedPeerValidator.PeerDeleted(ctx, accountID, peerID, settings.Extra); err != nil { + log.WithContext(ctx).Errorf("failed to delete peer %s from integrated validator: %v", peerID, err) + } + + if err = am.networkMapController.OnPeersDeleted(ctx, accountID, []string{peerID}); err != nil { log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peerID, err) } @@ -663,11 +677,10 @@ func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) { // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { var peer *nbpeer.Peer - var peerNotValid bool - var isStatusChanged bool var updated, versionChanged bool var err error var postureChecks []*posture.Checks + var peerGroupIDs []string settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { @@ -695,12 +708,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy return status.NewPeerLoginExpiredError() } - peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peer.ID) - if err != nil { - return err - } - - peerNotValid, isStatusChanged, err = am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra) + peerGroupIDs, err = getPeerGroupIDs(ctx, transaction, accountID, peer.ID) if err != nil { return err } @@ -724,6 +732,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy return nil, nil, nil, 0, err } + peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra) + if err != nil { + return nil, nil, nil, 0, err + } + if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) { err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) if err != nil { @@ -773,10 +786,9 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer var peer *nbpeer.Peer var updateRemotePeers bool - var isRequiresApproval bool - var isStatusChanged bool var isPeerUpdated bool var postureChecks []*posture.Checks + var peerGroupIDs []string settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { @@ -809,12 +821,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer } } - peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peer.ID) - if err != nil { - return err - } - - isRequiresApproval, isStatusChanged, err = am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra) + peerGroupIDs, err = getPeerGroupIDs(ctx, transaction, accountID, peer.ID) if err != nil { return err } @@ -852,6 +859,11 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer return nil, nil, nil, err } + isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra) + if err != nil { + return nil, nil, nil, err + } + if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) if err != nil { @@ -1057,7 +1069,7 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun } for _, p := range userPeers { - aclPeers, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap) + aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap, account.GetActiveGroupUsers()) for _, aclPeer := range aclPeers { if aclPeer.ID == peer.ID { return peer, nil @@ -1229,13 +1241,9 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str // deletePeers deletes all specified peers and sends updates to the remote peers. // Returns a slice of functions to save events after successful peer deletion. -func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) { +func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer, settings *types.Settings) ([]func(), error) { var peerDeletedEvents []func() - settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) - if err != nil { - return nil, err - } dnsDomain := am.networkMapController.GetDNSDomain(settings) for _, peer := range peers { @@ -1243,10 +1251,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto return nil, fmt.Errorf("failed to remove peer %s from groups", peer.ID) } - if err := am.integratedPeerValidator.PeerDeleted(ctx, accountID, peer.ID, settings.Extra); err != nil { - return nil, err - } - peerPolicyRules, err := transaction.GetPolicyRulesByResourceID(ctx, store.LockingStrengthNone, accountID, peer.ID) if err != nil { return nil, err diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 752563299..ce04adf9e 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -502,7 +502,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { accountID := "test_account" adminUser := "account_creator" someUser := "some_user" - account := newAccountWithId(context.Background(), accountID, adminUser, "", false) + account := newAccountWithId(context.Background(), accountID, adminUser, "", "", "", false) account.Users[someUser] = &types.User{ Id: someUser, Role: types.UserRoleUser, @@ -689,7 +689,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { accountID := "test_account" adminUser := "account_creator" someUser := "some_user" - account := newAccountWithId(context.Background(), accountID, adminUser, "", false) + account := newAccountWithId(context.Background(), accountID, adminUser, "", "", "", false) account.Users[someUser] = &types.User{ Id: someUser, Role: testCase.role, @@ -759,7 +759,7 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou adminUser := "account_creator" regularUser := "regular_user" - account := newAccountWithId(context.Background(), accountID, adminUser, "", false) + account := newAccountWithId(context.Background(), accountID, adminUser, "", "", "", false) account.Users[regularUser] = &types.User{ Id: regularUser, Role: types.UserRoleUser, @@ -2124,7 +2124,7 @@ func Test_DeletePeer(t *testing.T) { // account with an admin and a regular user accountID := "test_account" adminUser := "account_creator" - account := newAccountWithId(context.Background(), accountID, adminUser, "", false) + account := newAccountWithId(context.Background(), accountID, adminUser, "", "", "", false) account.Peers = map[string]*nbpeer.Peer{ "peer1": { ID: "peer1", @@ -2307,12 +2307,12 @@ func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) { } // Create account - account := newAccountWithId(context.Background(), "test-account", "owner", "", false) + account := newAccountWithId(context.Background(), "test-account", "owner", "", "", "", false) err = manager.Store.SaveAccount(context.Background(), account) require.NoError(t, err) // Create user pending approval - pendingUser := types.NewRegularUser("pending-user") + pendingUser := types.NewRegularUser("pending-user", "", "") pendingUser.AccountID = account.Id pendingUser.Blocked = true pendingUser.PendingApproval = true @@ -2344,12 +2344,12 @@ func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) { } // Create account - account := newAccountWithId(context.Background(), "test-account", "owner", "", false) + account := newAccountWithId(context.Background(), "test-account", "owner", "", "", "", false) err = manager.Store.SaveAccount(context.Background(), account) require.NoError(t, err) // Create regular user (not pending approval) - regularUser := types.NewRegularUser("regular-user") + regularUser := types.NewRegularUser("regular-user", "", "") regularUser.AccountID = account.Id err = manager.Store.SaveUser(context.Background(), regularUser) require.NoError(t, err) @@ -2378,12 +2378,12 @@ func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) { } // Create account - account := newAccountWithId(context.Background(), "test-account", "owner", "", false) + account := newAccountWithId(context.Background(), "test-account", "owner", "", "", "", false) err = manager.Store.SaveAccount(context.Background(), account) require.NoError(t, err) // Create user pending approval - pendingUser := types.NewRegularUser("pending-user") + pendingUser := types.NewRegularUser("pending-user", "", "") pendingUser.AccountID = account.Id pendingUser.Blocked = true pendingUser.PendingApproval = true @@ -2443,12 +2443,12 @@ func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) { } // Create account - account := newAccountWithId(context.Background(), "test-account", "owner", "", false) + account := newAccountWithId(context.Background(), "test-account", "owner", "", "", "", false) err = manager.Store.SaveAccount(context.Background(), account) require.NoError(t, err) // Create regular user (not pending approval) - regularUser := types.NewRegularUser("regular-user") + regularUser := types.NewRegularUser("regular-user", "", "") regularUser.AccountID = account.Id err = manager.Store.SaveUser(context.Background(), regularUser) require.NoError(t, err) diff --git a/management/server/permissions/modules/module.go b/management/server/permissions/modules/module.go index 3d021a235..0ae10d521 100644 --- a/management/server/permissions/modules/module.go +++ b/management/server/permissions/modules/module.go @@ -3,33 +3,35 @@ package modules type Module string const ( - Networks Module = "networks" - Peers Module = "peers" - Groups Module = "groups" - Settings Module = "settings" - Accounts Module = "accounts" - Dns Module = "dns" - Nameservers Module = "nameservers" - Events Module = "events" - Policies Module = "policies" - Routes Module = "routes" - Users Module = "users" - SetupKeys Module = "setup_keys" - Pats Module = "pats" + Networks Module = "networks" + Peers Module = "peers" + Groups Module = "groups" + Settings Module = "settings" + Accounts Module = "accounts" + Dns Module = "dns" + Nameservers Module = "nameservers" + Events Module = "events" + Policies Module = "policies" + Routes Module = "routes" + Users Module = "users" + SetupKeys Module = "setup_keys" + Pats Module = "pats" + IdentityProviders Module = "identity_providers" ) var All = map[Module]struct{}{ - Networks: {}, - Peers: {}, - Groups: {}, - Settings: {}, - Accounts: {}, - Dns: {}, - Nameservers: {}, - Events: {}, - Policies: {}, - Routes: {}, - Users: {}, - SetupKeys: {}, - Pats: {}, + Networks: {}, + Peers: {}, + Groups: {}, + Settings: {}, + Accounts: {}, + Dns: {}, + Nameservers: {}, + Events: {}, + Policies: {}, + Routes: {}, + Users: {}, + SetupKeys: {}, + Pats: {}, + IdentityProviders: {}, } diff --git a/management/server/permissions/roles/network_admin.go b/management/server/permissions/roles/network_admin.go index e95d58381..8f69d46ad 100644 --- a/management/server/permissions/roles/network_admin.go +++ b/management/server/permissions/roles/network_admin.go @@ -93,5 +93,11 @@ var NetworkAdmin = RolePermissions{ operations.Update: false, operations.Delete: false, }, + modules.IdentityProviders: { + operations.Read: true, + operations.Create: false, + operations.Update: false, + operations.Delete: false, + }, }, } diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 90fe8f036..a3f987732 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -246,14 +246,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { t.Run("check that all peers get map", func(t *testing.T) { for _, p := range account.Peers { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p, validatedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), p, validatedPeers, account.GetActiveGroupUsers()) assert.GreaterOrEqual(t, len(peers), 1, "minimum number peers should present") assert.GreaterOrEqual(t, len(firewallRules), 1, "minimum number of firewall rules should present") } }) t.Run("check first peer map details", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 8) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) @@ -509,7 +509,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { }) t.Run("check port ranges support for older peers", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 1) assert.Contains(t, peers, account.Peers["peerI"]) @@ -635,7 +635,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { } t.Run("check first peer map", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers()) assert.Contains(t, peers, account.Peers["peerC"]) expectedFirewallRules := []*types.FirewallRule{ @@ -665,7 +665,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers()) assert.Contains(t, peers, account.Peers["peerB"]) expectedFirewallRules := []*types.FirewallRule{ @@ -697,7 +697,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { account.Policies[1].Rules[0].Bidirectional = false t.Run("check first peer map directional only", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers()) assert.Contains(t, peers, account.Peers["peerC"]) expectedFirewallRules := []*types.FirewallRule{ @@ -719,7 +719,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map directional only", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers()) assert.Contains(t, peers, account.Peers["peerB"]) expectedFirewallRules := []*types.FirewallRule{ @@ -917,7 +917,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { t.Run("verify peer's network map with default group peer list", func(t *testing.T) { // peerB doesn't fulfill the NB posture check but is included in the destination group Swarm, // will establish a connection with all source peers satisfying the NB posture check. - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -927,7 +927,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers) + peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, 7) expectedFirewallRules := []*types.FirewallRule{ @@ -992,7 +992,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers) + peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -1002,7 +1002,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers) + peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -1017,19 +1017,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers) + peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers) + peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers) + peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers)) @@ -1044,14 +1044,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers) + peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 3) assert.Len(t, firewallRules, 3) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerD"]) - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers) + peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers, account.GetActiveGroupUsers()) assert.Len(t, peers, 5) // assert peers from Group Swarm assert.Contains(t, peers, account.Peers["peerD"]) diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index 13152ed12..7f0a48dc7 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -109,7 +109,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er ID: "peer1", } - account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, false) + account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, "", "", false) account.Users[admin.Id] = admin account.Users[user.Id] = user account.Peers["peer1"] = peer1 diff --git a/management/server/route_test.go b/management/server/route_test.go index a413d545b..6dc8c4cf4 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1320,7 +1320,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou accountID := "testingAcc" domain := "example.com" - account := newAccountWithId(context.Background(), accountID, userID, domain, false) + account := newAccountWithId(context.Background(), accountID, userID, domain, "", "", false) err := am.Store.SaveAccount(context.Background(), account) if err != nil { return nil, err diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index bc361bbd7..6eca27efd 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -15,6 +15,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" ) func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { @@ -24,7 +25,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { } userID := "testingUser" - account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID}) if err != nil { t.Fatal(err) } @@ -99,7 +100,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { } userID := "testingUser" - account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID}) if err != nil { t.Fatal(err) } @@ -204,7 +205,7 @@ func TestGetSetupKeys(t *testing.T) { } userID := "testingUser" - account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID}) if err != nil { t.Fatal(err) } @@ -471,7 +472,7 @@ func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t } userID := "testingUser" - account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID}) if err != nil { t.Fatal(err) } diff --git a/management/server/store/file_store.go b/management/server/store/file_store.go index d5d9337ca..8db37ec30 100644 --- a/management/server/store/file_store.go +++ b/management/server/store/file_store.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/management/server/types" nbutil "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/util" + "github.com/netbirdio/netbird/util/crypt" ) // storeFileName Store file name. Stored in the datadir @@ -263,3 +264,8 @@ func (s *FileStore) Close(ctx context.Context) error { func (s *FileStore) GetStoreEngine() types.Engine { return types.FileStoreEngine } + +// SetFieldEncrypt is a no-op for FileStore as it doesn't support field encryption. +func (s *FileStore) SetFieldEncrypt(_ *crypt.FieldEncrypt) { + // no-op: FileStore stores data in plaintext JSON; encryption is not supported +} diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 2b8981b97..f407a35e6 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -37,6 +37,7 @@ import ( "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/status" + "github.com/netbirdio/netbird/util/crypt" ) const ( @@ -57,12 +58,14 @@ const ( // SqlStore represents an account storage backed by a Sql DB persisted to disk type SqlStore struct { - db *gorm.DB - globalAccountLock sync.Mutex - metrics telemetry.AppMetrics - installationPK int - storeEngine types.Engine - pool *pgxpool.Pool + db *gorm.DB + globalAccountLock sync.Mutex + metrics telemetry.AppMetrics + installationPK int + storeEngine types.Engine + pool *pgxpool.Pool + fieldEncrypt *crypt.FieldEncrypt + transactionTimeout time.Duration } type installation struct { @@ -84,6 +87,14 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met conns = runtime.NumCPU() } + transactionTimeout := 5 * time.Minute + if v := os.Getenv("NB_STORE_TRANSACTION_TIMEOUT"); v != "" { + if parsed, err := time.ParseDuration(v); err == nil { + transactionTimeout = parsed + } + } + log.WithContext(ctx).Infof("Setting transaction timeout to %v", transactionTimeout) + if storeEngine == types.SqliteStoreEngine { if err == nil { log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1") @@ -101,7 +112,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met if skipMigration { log.WithContext(ctx).Infof("skipping migration") - return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil + return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1, transactionTimeout: transactionTimeout}, nil } if err := migratePreAuto(ctx, db); err != nil { @@ -120,7 +131,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met return nil, fmt.Errorf("migratePostAuto: %w", err) } - return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil + return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1, transactionTimeout: transactionTimeout}, nil } func GetKeyQueryCondition(s *SqlStore) string { @@ -165,6 +176,13 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro generateAccountSQLTypes(account) + // Encrypt sensitive user data before saving + for i := range account.UsersG { + if err := account.UsersG[i].EncryptSensitiveData(s.fieldEncrypt); err != nil { + return fmt.Errorf("encrypt user: %w", err) + } + } + for _, group := range account.GroupsG { group.StoreGroupPeers() } @@ -430,7 +448,18 @@ func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error { return nil } - result := s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&users) + usersCopy := make([]*types.User, len(users)) + for i, user := range users { + userCopy := user.Copy() + userCopy.Email = user.Email + userCopy.Name = user.Name + if err := userCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil { + return fmt.Errorf("encrypt user: %w", err) + } + usersCopy[i] = userCopy + } + + result := s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&usersCopy) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save users to store: %s", result.Error) return status.Errorf(status.Internal, "failed to save users to store") @@ -440,7 +469,15 @@ func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error { // SaveUser saves the given user to the database. func (s *SqlStore) SaveUser(ctx context.Context, user *types.User) error { - result := s.db.Save(user) + userCopy := user.Copy() + userCopy.Email = user.Email + userCopy.Name = user.Name + + if err := userCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil { + return fmt.Errorf("encrypt user: %w", err) + } + + result := s.db.Save(userCopy) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save user to store: %s", result.Error) return status.Errorf(status.Internal, "failed to save user to store") @@ -590,6 +627,10 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren return nil, status.NewGetUserFromStoreError() } + if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil { + return nil, fmt.Errorf("decrypt user: %w", err) + } + return &user, nil } @@ -608,6 +649,10 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre return nil, status.NewGetUserFromStoreError() } + if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil { + return nil, fmt.Errorf("decrypt user: %w", err) + } + return &user, nil } @@ -644,6 +689,12 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre return nil, status.Errorf(status.Internal, "issue getting users from store") } + for _, user := range users { + if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil { + return nil, fmt.Errorf("decrypt user: %w", err) + } + } + return users, nil } @@ -662,6 +713,10 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre return nil, status.Errorf(status.Internal, "failed to get account owner from the store") } + if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil { + return nil, fmt.Errorf("decrypt user: %w", err) + } + return &user, nil } @@ -856,6 +911,9 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types if user.AutoGroups == nil { user.AutoGroups = []string{} } + if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil { + return nil, fmt.Errorf("decrypt user: %w", err) + } account.Users[user.Id] = &user user.PATsG = nil } @@ -1131,6 +1189,9 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types. account.Users = make(map[string]*types.User, len(account.UsersG)) for i := range account.UsersG { user := &account.UsersG[i] + if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil { + return nil, fmt.Errorf("decrypt user: %w", err) + } user.PATs = make(map[string]*types.PersonalAccessToken) if userPats, ok := patsByUserID[user.Id]; ok { for j := range userPats { @@ -1535,7 +1596,7 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee } func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User, error) { - const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type FROM users WHERE account_id = $1` + const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type, email, name FROM users WHERE account_id = $1` rows, err := s.pool.Query(ctx, query, accountID) if err != nil { return nil, err @@ -1545,7 +1606,7 @@ func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User var autoGroups []byte var lastLogin, createdAt sql.NullTime var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool - err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType) + err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType, &u.Email, &u.Name) if err == nil { if lastLogin.Valid { u.LastLogin = &lastLogin.Time @@ -1910,16 +1971,17 @@ func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*t if len(policyIDs) == 0 { return nil, nil } - const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)` + const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges, authorized_groups, authorized_user FROM policy_rules WHERE policy_id = ANY($1)` rows, err := s.pool.Query(ctx, query, policyIDs) if err != nil { return nil, err } rules, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) { var r types.PolicyRule - var dest, destRes, sources, sourceRes, ports, portRanges []byte + var dest, destRes, sources, sourceRes, ports, portRanges, authorizedGroups []byte var enabled, bidirectional sql.NullBool - err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges) + var authorizedUser sql.NullString + err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges, &authorizedGroups, &authorizedUser) if err == nil { if enabled.Valid { r.Enabled = enabled.Bool @@ -1945,6 +2007,12 @@ func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*t if portRanges != nil { _ = json.Unmarshal(portRanges, &r.PortRanges) } + if authorizedGroups != nil { + _ = json.Unmarshal(authorizedGroups, &r.AuthorizedGroups) + } + if authorizedUser.Valid { + r.AuthorizedUser = authorizedUser.String + } } return &r, err }) @@ -2890,8 +2958,11 @@ func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) } func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error { + timeoutCtx, cancel := context.WithTimeout(context.Background(), s.transactionTimeout) + defer cancel() + startTime := time.Now() - tx := s.db.Begin() + tx := s.db.WithContext(timeoutCtx).Begin() if tx.Error != nil { return tx.Error } @@ -2926,6 +2997,9 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor err := operation(repo) if err != nil { tx.Rollback() + if errors.Is(err, context.DeadlineExceeded) || errors.Is(timeoutCtx.Err(), context.DeadlineExceeded) { + log.WithContext(ctx).Warnf("transaction exceeded %s timeout after %v, stack: %s", s.transactionTimeout, time.Since(startTime), debug.Stack()) + } return err } @@ -2938,19 +3012,26 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor } err = tx.Commit().Error + if err != nil { + if errors.Is(err, context.DeadlineExceeded) || errors.Is(timeoutCtx.Err(), context.DeadlineExceeded) { + log.WithContext(ctx).Warnf("transaction commit exceeded %s timeout after %v, stack: %s", s.transactionTimeout, time.Since(startTime), debug.Stack()) + } + return err + } log.WithContext(ctx).Tracef("transaction took %v", time.Since(startTime)) if s.metrics != nil { s.metrics.StoreMetrics().CountTransactionDuration(time.Since(startTime)) } - return err + return nil } func (s *SqlStore) withTx(tx *gorm.DB) Store { return &SqlStore{ - db: tx, - storeEngine: s.storeEngine, + db: tx, + storeEngine: s.storeEngine, + fieldEncrypt: s.fieldEncrypt, } } @@ -2983,6 +3064,11 @@ func (s *SqlStore) GetDB() *gorm.DB { return s.db } +// SetFieldEncrypt sets the field encryptor for encrypting sensitive user data. +func (s *SqlStore) SetFieldEncrypt(enc *crypt.FieldEncrypt) { + s.fieldEncrypt = enc +} + func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error) { tx := s.db if lockStrength != LockingStrengthNone { @@ -4075,3 +4161,21 @@ func (s *SqlStore) GetPeersByGroupIDs(ctx context.Context, accountID string, gro return peers, nil } + +func (s *SqlStore) GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var userID string + result := tx.Model(&nbpeer.Peer{}). + Select("user_id"). + Take(&userID, GetKeyQueryCondition(s), peerKey) + + if result.Error != nil { + return "", status.Errorf(status.Internal, "failed to get user ID by peer key") + } + + return userID, nil +} diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 2e2623910..97aa81b12 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -32,6 +32,7 @@ import ( nbroute "github.com/netbirdio/netbird/route" route2 "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/status" + "github.com/netbirdio/netbird/util/crypt" ) func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) { @@ -2090,7 +2091,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty setupKeys := map[string]*types.SetupKey{} nameServersGroups := make(map[string]*nbdns.NameServerGroup) - owner := types.NewOwnerUser(userID) + owner := types.NewOwnerUser(userID, "", "") owner.AccountID = accountID users[userID] = owner @@ -3114,6 +3115,138 @@ func TestSqlStore_SaveUsers(t *testing.T) { require.Equal(t, users[1].AutoGroups, user.AutoGroups) } +func TestSqlStore_SaveUserWithEncryption(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + // Enable encryption + key, err := crypt.GenerateKey() + require.NoError(t, err) + fieldEncrypt, err := crypt.NewFieldEncrypt(key) + require.NoError(t, err) + store.SetFieldEncrypt(fieldEncrypt) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + // rawUser is used to read raw (potentially encrypted) data from the database + // without any gorm hooks or automatic decryption + type rawUser struct { + Id string + Email string + Name string + } + + t.Run("save user with empty email and name", func(t *testing.T) { + user := &types.User{ + Id: "user-empty-fields", + AccountID: accountID, + Role: types.UserRoleUser, + Email: "", + Name: "", + AutoGroups: []string{"groupA"}, + } + err = store.SaveUser(context.Background(), user) + require.NoError(t, err) + + // Verify using direct database query that empty strings remain empty (not encrypted) + var raw rawUser + err = store.(*SqlStore).db.Table("users").Select("id, email, name").Where("id = ?", user.Id).First(&raw).Error + require.NoError(t, err) + require.Equal(t, "", raw.Email, "empty email should remain empty in database") + require.Equal(t, "", raw.Name, "empty name should remain empty in database") + + // Verify manual decryption returns empty strings + decryptedEmail, err := fieldEncrypt.Decrypt(raw.Email) + require.NoError(t, err) + require.Equal(t, "", decryptedEmail) + + decryptedName, err := fieldEncrypt.Decrypt(raw.Name) + require.NoError(t, err) + require.Equal(t, "", decryptedName) + }) + + t.Run("save user with email and name", func(t *testing.T) { + user := &types.User{ + Id: "user-with-fields", + AccountID: accountID, + Role: types.UserRoleAdmin, + Email: "test@example.com", + Name: "Test User", + AutoGroups: []string{"groupB"}, + } + err = store.SaveUser(context.Background(), user) + require.NoError(t, err) + + // Verify using direct database query that the data is encrypted (not plaintext) + var raw rawUser + err = store.(*SqlStore).db.Table("users").Select("id, email, name").Where("id = ?", user.Id).First(&raw).Error + require.NoError(t, err) + require.NotEqual(t, "test@example.com", raw.Email, "email should be encrypted in database") + require.NotEqual(t, "Test User", raw.Name, "name should be encrypted in database") + + // Verify manual decryption returns correct values + decryptedEmail, err := fieldEncrypt.Decrypt(raw.Email) + require.NoError(t, err) + require.Equal(t, "test@example.com", decryptedEmail) + + decryptedName, err := fieldEncrypt.Decrypt(raw.Name) + require.NoError(t, err) + require.Equal(t, "Test User", decryptedName) + }) + + t.Run("save multiple users with mixed fields", func(t *testing.T) { + users := []*types.User{ + { + Id: "batch-user-1", + AccountID: accountID, + Email: "", + Name: "", + }, + { + Id: "batch-user-2", + AccountID: accountID, + Email: "batch@example.com", + Name: "Batch User", + }, + } + err = store.SaveUsers(context.Background(), users) + require.NoError(t, err) + + // Verify first user (empty fields) using direct database query + var raw1 rawUser + err = store.(*SqlStore).db.Table("users").Select("id, email, name").Where("id = ?", "batch-user-1").First(&raw1).Error + require.NoError(t, err) + require.Equal(t, "", raw1.Email, "empty email should remain empty in database") + require.Equal(t, "", raw1.Name, "empty name should remain empty in database") + + // Verify second user (with fields) using direct database query + var raw2 rawUser + err = store.(*SqlStore).db.Table("users").Select("id, email, name").Where("id = ?", "batch-user-2").First(&raw2).Error + require.NoError(t, err) + require.NotEqual(t, "batch@example.com", raw2.Email, "email should be encrypted in database") + require.NotEqual(t, "Batch User", raw2.Name, "name should be encrypted in database") + + // Verify manual decryption returns empty strings for first user + decryptedEmail1, err := fieldEncrypt.Decrypt(raw1.Email) + require.NoError(t, err) + require.Equal(t, "", decryptedEmail1) + + decryptedName1, err := fieldEncrypt.Decrypt(raw1.Name) + require.NoError(t, err) + require.Equal(t, "", decryptedName1) + + // Verify manual decryption returns correct values for second user + decryptedEmail2, err := fieldEncrypt.Decrypt(raw2.Email) + require.NoError(t, err) + require.Equal(t, "batch@example.com", decryptedEmail2) + + decryptedName2, err := fieldEncrypt.Decrypt(raw2.Name) + require.NoError(t, err) + require.Equal(t, "Batch User", decryptedName2) + }) +} + func TestSqlStore_DeleteUser(t *testing.T) { store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) @@ -3718,6 +3851,69 @@ func TestSqlStore_GetPeersByGroupIDs(t *testing.T) { } } +func TestSqlStore_GetUserIDByPeerKey(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + userID := "test-user-123" + peerKey := "peer-key-abc" + + peer := &nbpeer.Peer{ + ID: "test-peer-1", + Key: peerKey, + AccountID: existingAccountID, + UserID: userID, + IP: net.IP{10, 0, 0, 1}, + DNSLabel: "test-peer-1", + } + + err = store.AddPeerToAccount(context.Background(), peer) + require.NoError(t, err) + + retrievedUserID, err := store.GetUserIDByPeerKey(context.Background(), LockingStrengthNone, peerKey) + require.NoError(t, err) + assert.Equal(t, userID, retrievedUserID) +} + +func TestSqlStore_GetUserIDByPeerKey_NotFound(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + nonExistentPeerKey := "non-existent-peer-key" + + userID, err := store.GetUserIDByPeerKey(context.Background(), LockingStrengthNone, nonExistentPeerKey) + require.Error(t, err) + assert.Equal(t, "", userID) +} + +func TestSqlStore_GetUserIDByPeerKey_NoUserID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + peerKey := "peer-key-abc" + + peer := &nbpeer.Peer{ + ID: "test-peer-1", + Key: peerKey, + AccountID: existingAccountID, + UserID: "", + IP: net.IP{10, 0, 0, 1}, + DNSLabel: "test-peer-1", + } + + err = store.AddPeerToAccount(context.Background(), peer) + require.NoError(t, err) + + retrievedUserID, err := store.GetUserIDByPeerKey(context.Background(), LockingStrengthNone, peerKey) + require.NoError(t, err) + assert.Equal(t, "", retrievedUserID) +} + func TestSqlStore_ApproveAccountPeers(t *testing.T) { runTestForAllEngines(t, "", func(t *testing.T, store Store) { accountID := "test-account" @@ -3794,3 +3990,30 @@ func TestSqlStore_ApproveAccountPeers(t *testing.T) { }) }) } + +func TestSqlStore_ExecuteInTransaction_Timeout(t *testing.T) { + if os.Getenv("NETBIRD_STORE_ENGINE") == "mysql" { + t.Skip("Skipping timeout test for MySQL") + } + + t.Setenv("NB_STORE_TRANSACTION_TIMEOUT", "1s") + + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + require.NoError(t, err) + t.Cleanup(cleanup) + + sqlStore, ok := store.(*SqlStore) + require.True(t, ok) + assert.Equal(t, 1*time.Second, sqlStore.transactionTimeout) + + ctx := context.Background() + err = sqlStore.ExecuteInTransaction(ctx, func(transaction Store) error { + // Sleep for 2 seconds to exceed the 1 second timeout + time.Sleep(2 * time.Second) + return nil + }) + + // The transaction should fail with an error (either timeout or already rolled back) + require.Error(t, err) + assert.Contains(t, err.Error(), "transaction has already been committed or rolled back", "expected transaction rolled back error, got: %v", err) +} diff --git a/management/server/store/store.go b/management/server/store/store.go index 0ec7949f9..013a66d73 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -27,6 +27,7 @@ import ( "github.com/netbirdio/netbird/management/server/testutil" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/util" + "github.com/netbirdio/netbird/util/crypt" "github.com/netbirdio/netbird/management/server/migration" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -204,6 +205,10 @@ type Store interface { MarkAccountPrimary(ctx context.Context, accountID string) error UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error GetPolicyRulesByResourceID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) ([]*types.PolicyRule, error) + + // SetFieldEncrypt sets the field encryptor for encrypting sensitive user data. + SetFieldEncrypt(enc *crypt.FieldEncrypt) + GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error) } const ( @@ -339,6 +344,12 @@ func getMigrationsPreAuto(ctx context.Context) []migrationFunc { func(db *gorm.DB) error { return migration.DropIndex[routerTypes.NetworkRouter](ctx, db, "idx_network_routers_id") }, + func(db *gorm.DB) error { + return migration.MigrateNewField[types.User](ctx, db, "name", "") + }, + func(db *gorm.DB) error { + return migration.MigrateNewField[types.User](ctx, db, "email", "") + }, } } // migratePostAuto migrates the SQLite database to the latest schema func migratePostAuto(ctx context.Context, db *gorm.DB) error { diff --git a/management/server/types/account.go b/management/server/types/account.go index 692dc4541..3d341fd62 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -16,6 +16,7 @@ import ( "github.com/rs/xid" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/ssh/auth" nbdns "github.com/netbirdio/netbird/dns" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" @@ -45,8 +46,10 @@ const ( // nativeSSHPortString defines the default port number as a string used for native SSH connections; this port is used by clients when hijacking ssh connections. nativeSSHPortString = "22022" + nativeSSHPortNumber = 22022 // defaultSSHPortString defines the standard SSH port number as a string, commonly used for default SSH connections. defaultSSHPortString = "22" + defaultSSHPortNumber = 22 ) type supportedFeatures struct { @@ -275,6 +278,7 @@ func (a *Account) GetPeerNetworkMap( resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter, metrics *telemetry.AccountManagerMetrics, + groupIDToUserIDs map[string][]string, ) *NetworkMap { start := time.Now() peer := a.Peers[peerID] @@ -290,7 +294,7 @@ func (a *Account) GetPeerNetworkMap( } } - aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap) + aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs) // exclude expired peers var peersToConnect []*nbpeer.Peer var expiredPeers []*nbpeer.Peer @@ -338,6 +342,8 @@ func (a *Account) GetPeerNetworkMap( OfflinePeers: expiredPeers, FirewallRules: firewallRules, RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules), + AuthorizedUsers: authorizedUsers, + EnableSSH: enableSSH, } if metrics != nil { @@ -1122,8 +1128,10 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map // GetPeerConnectionResources for a given peer // // This function returns the list of peers and firewall rules that are applicable to a given peer. -func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { +func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, groupIDToUserIDs map[string][]string) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, bool) { generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer) + authorizedUsers := make(map[string]map[string]struct{}) // machine user to list of userIDs + sshEnabled := false for _, policy := range a.Policies { if !policy.Enabled { @@ -1166,10 +1174,58 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.P if peerInDestinations { generateResources(rule, sourcePeers, FirewallRuleDirectionIN) } + + if peerInDestinations && rule.Protocol == PolicyRuleProtocolNetbirdSSH { + sshEnabled = true + switch { + case len(rule.AuthorizedGroups) > 0: + for groupID, localUsers := range rule.AuthorizedGroups { + userIDs, ok := groupIDToUserIDs[groupID] + if !ok { + log.WithContext(ctx).Tracef("no user IDs found for group ID %s", groupID) + continue + } + + if len(localUsers) == 0 { + localUsers = []string{auth.Wildcard} + } + + for _, localUser := range localUsers { + if authorizedUsers[localUser] == nil { + authorizedUsers[localUser] = make(map[string]struct{}) + } + for _, userID := range userIDs { + authorizedUsers[localUser][userID] = struct{}{} + } + } + } + case rule.AuthorizedUser != "": + if authorizedUsers[auth.Wildcard] == nil { + authorizedUsers[auth.Wildcard] = make(map[string]struct{}) + } + authorizedUsers[auth.Wildcard][rule.AuthorizedUser] = struct{}{} + default: + authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs() + } + } else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && peer.SSHEnabled { + sshEnabled = true + authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs() + } } } - return getAccumulatedResources() + peers, fwRules := getAccumulatedResources() + return peers, fwRules, authorizedUsers, sshEnabled +} + +func (a *Account) getAllowedUserIDs() map[string]struct{} { + users := make(map[string]struct{}) + for _, nbUser := range a.Users { + if !nbUser.IsBlocked() && !nbUser.IsServiceUser { + users[nbUser.Id] = struct{}{} + } + } + return users } // connResourcesGenerator returns generator and accumulator function which returns the result of generator calls @@ -1194,12 +1250,17 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer peersExists[peer.ID] = struct{}{} } + protocol := rule.Protocol + if protocol == PolicyRuleProtocolNetbirdSSH { + protocol = PolicyRuleProtocolTCP + } + fr := FirewallRule{ PolicyID: rule.ID, PeerIP: peer.IP.String(), Direction: direction, Action: string(rule.Action), - Protocol: string(rule.Protocol), + Protocol: string(protocol), } ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) + @@ -1221,6 +1282,28 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer } } +func policyRuleImpliesLegacySSH(rule *PolicyRule) bool { + return rule.Protocol == PolicyRuleProtocolALL || (rule.Protocol == PolicyRuleProtocolTCP && (portsIncludesSSH(rule.Ports) || portRangeIncludesSSH(rule.PortRanges))) +} + +func portRangeIncludesSSH(portRanges []RulePortRange) bool { + for _, pr := range portRanges { + if (pr.Start <= defaultSSHPortNumber && pr.End >= defaultSSHPortNumber) || (pr.Start <= nativeSSHPortNumber && pr.End >= nativeSSHPortNumber) { + return true + } + } + return false +} + +func portsIncludesSSH(ports []string) bool { + for _, port := range ports { + if port == defaultSSHPortString || port == nativeSSHPortString { + return true + } + } + return false +} + // getAllPeersFromGroups for given peer ID and list of groups // // Returns a list of peers from specified groups that pass specified posture checks @@ -1265,7 +1348,11 @@ func (a *Account) getPeerFromResource(resource Resource, peerID string) ([]*nbpe return []*nbpeer.Peer{}, false } - return []*nbpeer.Peer{peer}, resource.ID == peerID + if peer.ID == peerID { + return []*nbpeer.Peer{}, true + } + + return []*nbpeer.Peer{peer}, false } // validatePostureChecksOnPeer validates the posture checks on a peer @@ -1773,6 +1860,26 @@ func (a *Account) AddAllGroup(disableDefaultPolicy bool) error { return nil } +func (a *Account) GetActiveGroupUsers() map[string][]string { + allGroupID := "" + group, err := a.GetGroupAll() + if err != nil { + log.Errorf("failed to get group all: %v", err) + } else { + allGroupID = group.ID + } + groups := make(map[string][]string, len(a.GroupsG)) + for _, user := range a.Users { + if !user.IsBlocked() && !user.IsServiceUser { + for _, groupID := range user.AutoGroups { + groups[groupID] = append(groups[groupID], user.Id) + } + groups[allGroupID] = append(groups[allGroupID], user.Id) + } + } + return groups +} + // expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule { features := peerSupportedFirewallFeatures(peer.Meta.WtVersion) @@ -1804,7 +1911,7 @@ func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer expanded = append(expanded, &fr) } - if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) { + if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) || rule.Protocol == PolicyRuleProtocolNetbirdSSH { expanded = addNativeSSHRule(base, expanded) } diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index f9aa6a1c2..2c9f2428d 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -1105,6 +1105,193 @@ func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) { } } +func Test_GetActiveGroupUsers(t *testing.T) { + tests := []struct { + name string + account *Account + expected map[string][]string + }{ + { + name: "all users are active", + account: &Account{ + Users: map[string]*User{ + "user1": { + Id: "user1", + AutoGroups: []string{"group1", "group2"}, + Blocked: false, + }, + "user2": { + Id: "user2", + AutoGroups: []string{"group2", "group3"}, + Blocked: false, + }, + "user3": { + Id: "user3", + AutoGroups: []string{"group1"}, + Blocked: false, + }, + }, + }, + expected: map[string][]string{ + "group1": {"user1", "user3"}, + "group2": {"user1", "user2"}, + "group3": {"user2"}, + "": {"user1", "user2", "user3"}, + }, + }, + { + name: "some users are blocked", + account: &Account{ + Users: map[string]*User{ + "user1": { + Id: "user1", + AutoGroups: []string{"group1", "group2"}, + Blocked: false, + }, + "user2": { + Id: "user2", + AutoGroups: []string{"group2", "group3"}, + Blocked: true, + }, + "user3": { + Id: "user3", + AutoGroups: []string{"group1", "group3"}, + Blocked: false, + }, + }, + }, + expected: map[string][]string{ + "group1": {"user1", "user3"}, + "group2": {"user1"}, + "group3": {"user3"}, + "": {"user1", "user3"}, + }, + }, + { + name: "all users are blocked", + account: &Account{ + Users: map[string]*User{ + "user1": { + Id: "user1", + AutoGroups: []string{"group1"}, + Blocked: true, + }, + "user2": { + Id: "user2", + AutoGroups: []string{"group2"}, + Blocked: true, + }, + }, + }, + expected: map[string][]string{}, + }, + { + name: "user with no auto groups", + account: &Account{ + Users: map[string]*User{ + "user1": { + Id: "user1", + AutoGroups: []string{}, + Blocked: false, + }, + "user2": { + Id: "user2", + AutoGroups: []string{"group1"}, + Blocked: false, + }, + }, + }, + expected: map[string][]string{ + "group1": {"user2"}, + "": {"user1", "user2"}, + }, + }, + { + name: "empty account", + account: &Account{ + Users: map[string]*User{}, + }, + expected: map[string][]string{}, + }, + { + name: "multiple users in same group", + account: &Account{ + Users: map[string]*User{ + "user1": { + Id: "user1", + AutoGroups: []string{"group1"}, + Blocked: false, + }, + "user2": { + Id: "user2", + AutoGroups: []string{"group1"}, + Blocked: false, + }, + "user3": { + Id: "user3", + AutoGroups: []string{"group1"}, + Blocked: false, + }, + }, + }, + expected: map[string][]string{ + "group1": {"user1", "user2", "user3"}, + "": {"user1", "user2", "user3"}, + }, + }, + { + name: "user in multiple groups with blocked users", + account: &Account{ + Users: map[string]*User{ + "user1": { + Id: "user1", + AutoGroups: []string{"group1", "group2", "group3"}, + Blocked: false, + }, + "user2": { + Id: "user2", + AutoGroups: []string{"group1", "group2"}, + Blocked: true, + }, + "user3": { + Id: "user3", + AutoGroups: []string{"group3"}, + Blocked: false, + }, + }, + }, + expected: map[string][]string{ + "group1": {"user1"}, + "group2": {"user1"}, + "group3": {"user1", "user3"}, + "": {"user1", "user3"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.account.GetActiveGroupUsers() + + // Check that the number of groups matches + assert.Equal(t, len(tt.expected), len(result), "number of groups should match") + + // Check each group's users + for groupID, expectedUsers := range tt.expected { + actualUsers, exists := result[groupID] + assert.True(t, exists, "group %s should exist in result", groupID) + assert.ElementsMatch(t, expectedUsers, actualUsers, "users in group %s should match", groupID) + } + + // Ensure no extra groups in result + for groupID := range result { + _, exists := tt.expected[groupID] + assert.True(t, exists, "unexpected group %s in result", groupID) + } + }) + } +} + func Test_FilterZoneRecordsForPeers(t *testing.T) { tests := []struct { name string diff --git a/management/server/types/holder.go b/management/server/types/holder.go index 3996db2b6..de8ac8110 100644 --- a/management/server/types/holder.go +++ b/management/server/types/holder.go @@ -25,16 +25,20 @@ func (h *Holder) GetAccount(id string) *Account { func (h *Holder) AddAccount(account *Account) { h.mu.Lock() defer h.mu.Unlock() + a := h.accounts[account.Id] + if a != nil && a.Network.CurrentSerial() >= account.Network.CurrentSerial() { + return + } h.accounts[account.Id] = account } -func (h *Holder) LoadOrStoreFunc(id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) { +func (h *Holder) LoadOrStoreFunc(ctx context.Context, id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) { h.mu.Lock() defer h.mu.Unlock() if acc, ok := h.accounts[id]; ok { return acc, nil } - account, err := accGetter(context.Background(), id) + account, err := accGetter(ctx, id) if err != nil { return nil, err } diff --git a/management/server/types/identity_provider.go b/management/server/types/identity_provider.go new file mode 100644 index 000000000..e809590de --- /dev/null +++ b/management/server/types/identity_provider.go @@ -0,0 +1,122 @@ +package types + +import ( + "errors" + "net/url" +) + +// Identity provider validation errors +var ( + ErrIdentityProviderNameRequired = errors.New("identity provider name is required") + ErrIdentityProviderTypeRequired = errors.New("identity provider type is required") + ErrIdentityProviderTypeUnsupported = errors.New("unsupported identity provider type") + ErrIdentityProviderIssuerRequired = errors.New("identity provider issuer is required") + ErrIdentityProviderIssuerInvalid = errors.New("identity provider issuer must be a valid URL") + ErrIdentityProviderClientIDRequired = errors.New("identity provider client ID is required") +) + +// IdentityProviderType is the type of identity provider +type IdentityProviderType string + +const ( + // IdentityProviderTypeOIDC is a generic OIDC identity provider + IdentityProviderTypeOIDC IdentityProviderType = "oidc" + // IdentityProviderTypeZitadel is the Zitadel identity provider + IdentityProviderTypeZitadel IdentityProviderType = "zitadel" + // IdentityProviderTypeEntra is the Microsoft Entra (Azure AD) identity provider + IdentityProviderTypeEntra IdentityProviderType = "entra" + // IdentityProviderTypeGoogle is the Google identity provider + IdentityProviderTypeGoogle IdentityProviderType = "google" + // IdentityProviderTypeOkta is the Okta identity provider + IdentityProviderTypeOkta IdentityProviderType = "okta" + // IdentityProviderTypePocketID is the PocketID identity provider + IdentityProviderTypePocketID IdentityProviderType = "pocketid" + // IdentityProviderTypeMicrosoft is the Microsoft identity provider + IdentityProviderTypeMicrosoft IdentityProviderType = "microsoft" + // IdentityProviderTypeAuthentik is the Authentik identity provider + IdentityProviderTypeAuthentik IdentityProviderType = "authentik" + // IdentityProviderTypeKeycloak is the Keycloak identity provider + IdentityProviderTypeKeycloak IdentityProviderType = "keycloak" +) + +// IdentityProvider represents an identity provider configuration +type IdentityProvider struct { + // ID is the unique identifier of the identity provider + ID string `gorm:"primaryKey"` + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index"` + // Type is the type of identity provider + Type IdentityProviderType + // Name is a human-readable name for the identity provider + Name string + // Issuer is the OIDC issuer URL + Issuer string + // ClientID is the OAuth2 client ID + ClientID string + // ClientSecret is the OAuth2 client secret + ClientSecret string +} + +// Copy returns a copy of the IdentityProvider +func (idp *IdentityProvider) Copy() *IdentityProvider { + return &IdentityProvider{ + ID: idp.ID, + AccountID: idp.AccountID, + Type: idp.Type, + Name: idp.Name, + Issuer: idp.Issuer, + ClientID: idp.ClientID, + ClientSecret: idp.ClientSecret, + } +} + +// EventMeta returns a map of metadata for activity events +func (idp *IdentityProvider) EventMeta() map[string]any { + return map[string]any{ + "name": idp.Name, + "type": string(idp.Type), + "issuer": idp.Issuer, + } +} + +// Validate validates the identity provider configuration +func (idp *IdentityProvider) Validate() error { + if idp.Name == "" { + return ErrIdentityProviderNameRequired + } + if idp.Type == "" { + return ErrIdentityProviderTypeRequired + } + if !idp.Type.IsValid() { + return ErrIdentityProviderTypeUnsupported + } + if !idp.Type.HasBuiltInIssuer() && idp.Issuer == "" { + return ErrIdentityProviderIssuerRequired + } + if idp.Issuer != "" { + parsedURL, err := url.Parse(idp.Issuer) + if err != nil || parsedURL.Scheme == "" || parsedURL.Host == "" { + return ErrIdentityProviderIssuerInvalid + } + } + if idp.ClientID == "" { + return ErrIdentityProviderClientIDRequired + } + return nil +} + +// IsValid checks if the given type is a supported identity provider type +func (t IdentityProviderType) IsValid() bool { + switch t { + case IdentityProviderTypeOIDC, IdentityProviderTypeZitadel, IdentityProviderTypeEntra, + IdentityProviderTypeGoogle, IdentityProviderTypeOkta, IdentityProviderTypePocketID, + IdentityProviderTypeMicrosoft, IdentityProviderTypeAuthentik, IdentityProviderTypeKeycloak: + return true + } + return false +} + +// HasBuiltInIssuer returns true for types that don't require an issuer URL +func (t IdentityProviderType) HasBuiltInIssuer() bool { + return t == IdentityProviderTypeGoogle || t == IdentityProviderTypeMicrosoft +} diff --git a/management/server/types/identity_provider_test.go b/management/server/types/identity_provider_test.go new file mode 100644 index 000000000..6ddc563f2 --- /dev/null +++ b/management/server/types/identity_provider_test.go @@ -0,0 +1,137 @@ +package types + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIdentityProvider_Validate(t *testing.T) { + tests := []struct { + name string + idp *IdentityProvider + expectedErr error + }{ + { + name: "valid OIDC provider", + idp: &IdentityProvider{ + Name: "Test Provider", + Type: IdentityProviderTypeOIDC, + Issuer: "https://example.com", + ClientID: "client-id", + }, + expectedErr: nil, + }, + { + name: "valid OIDC provider with path", + idp: &IdentityProvider{ + Name: "Test Provider", + Type: IdentityProviderTypeOIDC, + Issuer: "https://example.com/oauth2/issuer", + ClientID: "client-id", + }, + expectedErr: nil, + }, + { + name: "missing name", + idp: &IdentityProvider{ + Type: IdentityProviderTypeOIDC, + Issuer: "https://example.com", + ClientID: "client-id", + }, + expectedErr: ErrIdentityProviderNameRequired, + }, + { + name: "missing type", + idp: &IdentityProvider{ + Name: "Test Provider", + Issuer: "https://example.com", + ClientID: "client-id", + }, + expectedErr: ErrIdentityProviderTypeRequired, + }, + { + name: "invalid type", + idp: &IdentityProvider{ + Name: "Test Provider", + Type: "invalid", + Issuer: "https://example.com", + ClientID: "client-id", + }, + expectedErr: ErrIdentityProviderTypeUnsupported, + }, + { + name: "missing issuer for OIDC", + idp: &IdentityProvider{ + Name: "Test Provider", + Type: IdentityProviderTypeOIDC, + ClientID: "client-id", + }, + expectedErr: ErrIdentityProviderIssuerRequired, + }, + { + name: "invalid issuer URL - no scheme", + idp: &IdentityProvider{ + Name: "Test Provider", + Type: IdentityProviderTypeOIDC, + Issuer: "example.com", + ClientID: "client-id", + }, + expectedErr: ErrIdentityProviderIssuerInvalid, + }, + { + name: "invalid issuer URL - no host", + idp: &IdentityProvider{ + Name: "Test Provider", + Type: IdentityProviderTypeOIDC, + Issuer: "https://", + ClientID: "client-id", + }, + expectedErr: ErrIdentityProviderIssuerInvalid, + }, + { + name: "invalid issuer URL - just path", + idp: &IdentityProvider{ + Name: "Test Provider", + Type: IdentityProviderTypeOIDC, + Issuer: "/oauth2/issuer", + ClientID: "client-id", + }, + expectedErr: ErrIdentityProviderIssuerInvalid, + }, + { + name: "missing client ID", + idp: &IdentityProvider{ + Name: "Test Provider", + Type: IdentityProviderTypeOIDC, + Issuer: "https://example.com", + }, + expectedErr: ErrIdentityProviderClientIDRequired, + }, + { + name: "Google provider without issuer is valid", + idp: &IdentityProvider{ + Name: "Google SSO", + Type: IdentityProviderTypeGoogle, + ClientID: "client-id", + }, + expectedErr: nil, + }, + { + name: "Microsoft provider without issuer is valid", + idp: &IdentityProvider{ + Name: "Microsoft SSO", + Type: IdentityProviderTypeMicrosoft, + ClientID: "client-id", + }, + expectedErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.idp.Validate() + assert.Equal(t, tt.expectedErr, err) + }) + } +} diff --git a/management/server/types/network.go b/management/server/types/network.go index 0f45d410a..e42f6e1c7 100644 --- a/management/server/types/network.go +++ b/management/server/types/network.go @@ -39,6 +39,8 @@ type NetworkMap struct { FirewallRules []*FirewallRule RoutesFirewallRules []*RouteFirewallRule ForwardingRules []*ForwardingRule + AuthorizedUsers map[string]map[string]struct{} + EnableSSH bool } func (nm *NetworkMap) Merge(other *NetworkMap) { diff --git a/management/server/types/networkmap.go b/management/server/types/networkmap.go index c6551dcbb..5d90de9c6 100644 --- a/management/server/types/networkmap.go +++ b/management/server/types/networkmap.go @@ -47,14 +47,21 @@ func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error { if a.NetworkMapCache == nil { return nil } - return a.NetworkMapCache.OnPeerAddedIncremental(peerId) + return a.NetworkMapCache.OnPeerAddedIncremental(a, peerId) +} + +func (a *Account) OnPeersAddedUpdNetworkMapCache(peerIds ...string) { + if a.NetworkMapCache == nil { + return + } + a.NetworkMapCache.EnqueuePeersForIncrementalAdd(a, peerIds...) } func (a *Account) OnPeerDeletedUpdNetworkMapCache(peerId string) error { if a.NetworkMapCache == nil { return nil } - return a.NetworkMapCache.OnPeerDeleted(peerId) + return a.NetworkMapCache.OnPeerDeleted(a, peerId) } func (a *Account) UpdatePeerInNetworkMapCache(peer *nbpeer.Peer) { diff --git a/management/server/types/networkmap_golden_test.go b/management/server/types/networkmap_golden_test.go index fa2fe56b2..290cdf273 100644 --- a/management/server/types/networkmap_golden_test.go +++ b/management/server/types/networkmap_golden_test.go @@ -25,15 +25,12 @@ import ( "github.com/netbirdio/netbird/route" ) -// update flag is used to update the golden file. -// example: go test ./... -v -update -// var update = flag.Bool("update", false, "update golden files") - const ( numPeers = 100 devGroupID = "group-dev" opsGroupID = "group-ops" allGroupID = "group-all" + sshUsersGroupID = "group-ssh-users" routeID = route.ID("route-main") routeHA1ID = route.ID("route-ha-1") routeHA2ID = route.ID("route-ha-2") @@ -41,6 +38,7 @@ const ( policyIDAll = "policy-all" policyIDPosture = "policy-posture" policyIDDrop = "policy-drop" + policyIDSSH = "policy-ssh" postureCheckID = "posture-check-ver" networkResourceID = "res-database" networkID = "net-database" @@ -51,6 +49,9 @@ const ( offlinePeerID = "peer-99" // This peer will be completely offline. routingPeerID = "peer-95" // This peer is used for routing, it has a route to the network. testAccountID = "account-golden-test" + userAdminID = "user-admin" + userDevID = "user-dev" + userOpsID = "user-ops" ) func TestGetPeerNetworkMap_Golden(t *testing.T) { @@ -69,61 +70,34 @@ func TestGetPeerNetworkMap_Golden(t *testing.T) { resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() - networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) - - normalizeAndSortNetworkMap(networkMap) - - jsonData, err := json.MarshalIndent(networkMap, "", " ") - require.NoError(t, err, "error marshaling network map to JSON") - - goldenFilePath := filepath.Join("testdata", "networkmap_golden.json") - - t.Log("Update golden file...") - err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) - require.NoError(t, err) - err = os.WriteFile(goldenFilePath, jsonData, 0644) - require.NoError(t, err) - - expectedJSON, err := os.ReadFile(goldenFilePath) - require.NoError(t, err, "error reading golden file") - - require.JSONEq(t, string(expectedJSON), string(jsonData), "resulted network map from OLD method does not match golden file") -} - -func TestGetPeerNetworkMap_Golden_New(t *testing.T) { - account := createTestAccountWithEntities() - - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - - if peerID == offlinePeerID { - continue - } - validatedPeersMap[peerID] = struct{}{} - } + legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) + normalizeAndSortNetworkMap(legacyNetworkMap) + legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ") + require.NoError(t, err, "error marshaling legacy network map to JSON") builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + normalizeAndSortNetworkMap(newNetworkMap) + newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ") + require.NoError(t, err, "error marshaling new network map to JSON") - normalizeAndSortNetworkMap(networkMap) + if string(legacyJSON) != string(newJSON) { + legacyFilePath := filepath.Join("testdata", "networkmap_golden.json") + newFilePath := filepath.Join("testdata", "networkmap_golden_new.json") - jsonData, err := json.MarshalIndent(networkMap, "", " ") - require.NoError(t, err, "error marshaling network map to JSON") + err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755) + require.NoError(t, err) - goldenFilePath := filepath.Join("testdata", "networkmap_golden_new.json") + err = os.WriteFile(legacyFilePath, legacyJSON, 0644) + require.NoError(t, err) + t.Logf("Saved legacy network map to %s", legacyFilePath) - t.Log("Update golden file...") - err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) - require.NoError(t, err) - err = os.WriteFile(goldenFilePath, jsonData, 0644) - require.NoError(t, err) + err = os.WriteFile(newFilePath, newJSON, 0644) + require.NoError(t, err) + t.Logf("Saved new network map to %s", newFilePath) - expectedJSON, err := os.ReadFile(goldenFilePath) - require.NoError(t, err, "error reading golden file") - - require.JSONEq(t, string(expectedJSON), string(jsonData), "resulted network map from NEW builder does not match golden file") + require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps from legacy and new builder do not match") + } } func BenchmarkGetPeerNetworkMap(b *testing.B) { @@ -141,7 +115,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) { b.Run("old builder", func(b *testing.B) { for range b.N { for _, peerID := range peerIDs { - _ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + _ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers()) } } }) @@ -169,6 +143,8 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) { validatedPeersMap[peerID] = struct{}{} } + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + newPeerID := "peer-new-101" newPeerIP := net.IP{100, 64, 1, 1} newPeer := &nbpeer.Peer{ @@ -201,92 +177,36 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) { resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() - networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) + normalizeAndSortNetworkMap(legacyNetworkMap) + legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ") + require.NoError(t, err, "error marshaling legacy network map to JSON") - normalizeAndSortNetworkMap(networkMap) - - jsonData, err := json.MarshalIndent(networkMap, "", " ") - require.NoError(t, err, "error marshaling network map to JSON") - - goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_new_peer.json") - - t.Log("Update golden file with new peer...") - err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) - require.NoError(t, err) - err = os.WriteFile(goldenFilePath, jsonData, 0644) - require.NoError(t, err) - - expectedJSON, err := os.ReadFile(goldenFilePath) - require.NoError(t, err, "error reading golden file") - - require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with new peer does not match golden file") -} - -func TestGetPeerNetworkMap_Golden_New_WithOnPeerAdded(t *testing.T) { - account := createTestAccountWithEntities() - - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - if peerID == offlinePeerID { - continue - } - validatedPeersMap[peerID] = struct{}{} - } - - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - - newPeerID := "peer-new-101" - newPeerIP := net.IP{100, 64, 1, 1} - newPeer := &nbpeer.Peer{ - ID: newPeerID, - IP: newPeerIP, - Key: fmt.Sprintf("key-%s", newPeerID), - DNSLabel: "peernew101", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, - UserID: "user-admin", - Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, - LastLogin: func() *time.Time { t := time.Now(); return &t }(), - } - - account.Peers[newPeerID] = newPeer - - if devGroup, exists := account.Groups[devGroupID]; exists { - devGroup.Peers = append(devGroup.Peers, newPeerID) - } - - if allGroup, exists := account.Groups[allGroupID]; exists { - allGroup.Peers = append(allGroup.Peers, newPeerID) - } - - validatedPeersMap[newPeerID] = struct{}{} - - if account.Network != nil { - account.Network.Serial++ - } - - err := builder.OnPeerAddedIncremental(newPeerID) + err = builder.OnPeerAddedIncremental(account, newPeerID) require.NoError(t, err, "error adding peer to cache") - networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + normalizeAndSortNetworkMap(newNetworkMap) + newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ") + require.NoError(t, err, "error marshaling new network map to JSON") - normalizeAndSortNetworkMap(networkMap) + if string(legacyJSON) != string(newJSON) { + legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_new_peer.json") + newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded.json") - jsonData, err := json.MarshalIndent(networkMap, "", " ") - require.NoError(t, err, "error marshaling network map to JSON") + err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755) + require.NoError(t, err) - goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded.json") - t.Log("Update golden file with OnPeerAdded...") - err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) - require.NoError(t, err) - err = os.WriteFile(goldenFilePath, jsonData, 0644) - require.NoError(t, err) + err = os.WriteFile(legacyFilePath, legacyJSON, 0644) + require.NoError(t, err) + t.Logf("Saved legacy network map to %s", legacyFilePath) - expectedJSON, err := os.ReadFile(goldenFilePath) - require.NoError(t, err, "error reading golden file") + err = os.WriteFile(newFilePath, newJSON, 0644) + require.NoError(t, err) + t.Logf("Saved new network map to %s", newFilePath) - require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded does not match golden file") + require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with new peer from legacy and new builder do not match") + } } func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) { @@ -320,7 +240,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) { b.Run("old builder after add", func(b *testing.B) { for i := 0; i < b.N; i++ { for _, testingPeerID := range peerIDs { - _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers()) } } }) @@ -328,7 +248,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) { b.ResetTimer() b.Run("new builder after add", func(b *testing.B) { for i := 0; i < b.N; i++ { - _ = builder.OnPeerAddedIncremental(newPeerID) + _ = builder.OnPeerAddedIncremental(account, newPeerID) for _, testingPeerID := range peerIDs { _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) } @@ -349,6 +269,8 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) { validatedPeersMap[peerID] = struct{}{} } + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + newRouterID := "peer-new-router-102" newRouterIP := net.IP{100, 64, 1, 2} newRouter := &nbpeer.Peer{ @@ -395,106 +317,36 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) { resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() - networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) + normalizeAndSortNetworkMap(legacyNetworkMap) + legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ") + require.NoError(t, err, "error marshaling legacy network map to JSON") - normalizeAndSortNetworkMap(networkMap) - - jsonData, err := json.MarshalIndent(networkMap, "", " ") - require.NoError(t, err, "error marshaling network map to JSON") - - goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_new_router.json") - - t.Log("Update golden file with new router...") - err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) - require.NoError(t, err) - err = os.WriteFile(goldenFilePath, jsonData, 0644) - require.NoError(t, err) - - expectedJSON, err := os.ReadFile(goldenFilePath) - require.NoError(t, err, "error reading golden file") - - require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with new router does not match golden file") -} - -func TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter(t *testing.T) { - account := createTestAccountWithEntities() - - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - if peerID == offlinePeerID { - continue - } - validatedPeersMap[peerID] = struct{}{} - } - - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - - newRouterID := "peer-new-router-102" - newRouterIP := net.IP{100, 64, 1, 2} - newRouter := &nbpeer.Peer{ - ID: newRouterID, - IP: newRouterIP, - Key: fmt.Sprintf("key-%s", newRouterID), - DNSLabel: "newrouter102", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, - UserID: "user-admin", - Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, - LastLogin: func() *time.Time { t := time.Now(); return &t }(), - } - - account.Peers[newRouterID] = newRouter - - if opsGroup, exists := account.Groups[opsGroupID]; exists { - opsGroup.Peers = append(opsGroup.Peers, newRouterID) - } - if allGroup, exists := account.Groups[allGroupID]; exists { - allGroup.Peers = append(allGroup.Peers, newRouterID) - } - - newRoute := &route.Route{ - ID: route.ID("route-new-router"), - Network: netip.MustParsePrefix("172.16.0.0/24"), - Peer: newRouter.Key, - PeerID: newRouterID, - Description: "Route from new router", - Enabled: true, - PeerGroups: []string{opsGroupID}, - Groups: []string{devGroupID, opsGroupID}, - AccessControlGroups: []string{devGroupID}, - AccountID: account.Id, - } - account.Routes[newRoute.ID] = newRoute - - validatedPeersMap[newRouterID] = struct{}{} - - if account.Network != nil { - account.Network.Serial++ - } - - err := builder.OnPeerAddedIncremental(newRouterID) + err = builder.OnPeerAddedIncremental(account, newRouterID) require.NoError(t, err, "error adding router to cache") - networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + normalizeAndSortNetworkMap(newNetworkMap) + newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ") + require.NoError(t, err, "error marshaling new network map to JSON") - normalizeAndSortNetworkMap(networkMap) + if string(legacyJSON) != string(newJSON) { + legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_new_router.json") + newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json") - jsonData, err := json.MarshalIndent(networkMap, "", " ") - require.NoError(t, err, "error marshaling network map to JSON") + err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755) + require.NoError(t, err) - goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json") + err = os.WriteFile(legacyFilePath, legacyJSON, 0644) + require.NoError(t, err) + t.Logf("Saved legacy network map to %s", legacyFilePath) - t.Log("Update golden file with OnPeerAdded router...") - err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) - require.NoError(t, err) - err = os.WriteFile(goldenFilePath, jsonData, 0644) - require.NoError(t, err) + err = os.WriteFile(newFilePath, newJSON, 0644) + require.NoError(t, err) + t.Logf("Saved new network map to %s", newFilePath) - expectedJSON, err := os.ReadFile(goldenFilePath) - require.NoError(t, err, "error reading golden file") - - require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded router does not match golden file") + require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with new router from legacy and new builder do not match") + } } func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) { @@ -550,7 +402,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) { b.Run("old builder after add", func(b *testing.B) { for i := 0; i < b.N; i++ { for _, testingPeerID := range peerIDs { - _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers()) } } }) @@ -558,7 +410,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) { b.ResetTimer() b.Run("new builder after add", func(b *testing.B) { for i := 0; i < b.N; i++ { - _ = builder.OnPeerAddedIncremental(newRouterID) + _ = builder.OnPeerAddedIncremental(account, newRouterID) for _, testingPeerID := range peerIDs { _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) } @@ -579,7 +431,9 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) { validatedPeersMap[peerID] = struct{}{} } - deletedPeerID := "peer-25" // peer from devs group + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + deletedPeerID := "peer-25" delete(account.Peers, deletedPeerID) @@ -604,85 +458,36 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) { resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() - networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) + normalizeAndSortNetworkMap(legacyNetworkMap) + legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ") + require.NoError(t, err, "error marshaling legacy network map to JSON") - normalizeAndSortNetworkMap(networkMap) - - jsonData, err := json.MarshalIndent(networkMap, "", " ") - require.NoError(t, err, "error marshaling network map to JSON") - - goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_peer.json") - - t.Log("Update golden file with deleted peer...") - err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) - require.NoError(t, err) - err = os.WriteFile(goldenFilePath, jsonData, 0644) - require.NoError(t, err) - - expectedJSON, err := os.ReadFile(goldenFilePath) - require.NoError(t, err, "error reading golden file") - - require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with deleted peer does not match golden file") -} - -func TestGetPeerNetworkMap_Golden_New_WithOnPeerDeleted(t *testing.T) { - account := createTestAccountWithEntities() - - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - if peerID == offlinePeerID { - continue - } - validatedPeersMap[peerID] = struct{}{} - } - - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - - deletedPeerID := "peer-25" // devs group peer - - delete(account.Peers, deletedPeerID) - - if devGroup, exists := account.Groups[devGroupID]; exists { - devGroup.Peers = slices.DeleteFunc(devGroup.Peers, func(id string) bool { - return id == deletedPeerID - }) - } - - if allGroup, exists := account.Groups[allGroupID]; exists { - allGroup.Peers = slices.DeleteFunc(allGroup.Peers, func(id string) bool { - return id == deletedPeerID - }) - } - - delete(validatedPeersMap, deletedPeerID) - - if account.Network != nil { - account.Network.Serial++ - } - - err := builder.OnPeerDeleted(deletedPeerID) + err = builder.OnPeerDeleted(account, deletedPeerID) require.NoError(t, err, "error deleting peer from cache") - networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + normalizeAndSortNetworkMap(newNetworkMap) + newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ") + require.NoError(t, err, "error marshaling new network map to JSON") - normalizeAndSortNetworkMap(networkMap) + if string(legacyJSON) != string(newJSON) { + legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_peer.json") + newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeerdeleted.json") - jsonData, err := json.MarshalIndent(networkMap, "", " ") - require.NoError(t, err, "error marshaling network map to JSON") + err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755) + require.NoError(t, err) - goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeerdeleted.json") - t.Log("Update golden file with OnPeerDeleted...") - err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) - require.NoError(t, err) - err = os.WriteFile(goldenFilePath, jsonData, 0644) - require.NoError(t, err) + err = os.WriteFile(legacyFilePath, legacyJSON, 0644) + require.NoError(t, err) + t.Logf("Saved legacy network map to %s", legacyFilePath) - expectedJSON, err := os.ReadFile(goldenFilePath) - require.NoError(t, err, "error reading golden file") + err = os.WriteFile(newFilePath, newJSON, 0644) + require.NoError(t, err) + t.Logf("Saved new network map to %s", newFilePath) - require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerDeleted does not match golden file") + require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with deleted peer from legacy and new builder do not match") + } } func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) { @@ -698,7 +503,9 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) { validatedPeersMap[peerID] = struct{}{} } - deletedRouterID := "peer-75" // router peer + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + deletedRouterID := "peer-75" var affectedRoute *route.Route for _, r := range account.Routes { @@ -730,93 +537,36 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) { resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() - networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) + normalizeAndSortNetworkMap(legacyNetworkMap) + legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ") + require.NoError(t, err, "error marshaling legacy network map to JSON") - normalizeAndSortNetworkMap(networkMap) - - jsonData, err := json.MarshalIndent(networkMap, "", " ") - require.NoError(t, err, "error marshaling network map to JSON") - - goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_router_peer.json") - - t.Log("Update golden file with deleted peer...") - err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) - require.NoError(t, err) - err = os.WriteFile(goldenFilePath, jsonData, 0644) - require.NoError(t, err) - - expectedJSON, err := os.ReadFile(goldenFilePath) - require.NoError(t, err, "error reading golden file") - - require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with deleted peer does not match golden file") -} - -func TestGetPeerNetworkMap_Golden_New_WithDeletedRouterPeer(t *testing.T) { - account := createTestAccountWithEntities() - - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - if peerID == offlinePeerID { - continue - } - validatedPeersMap[peerID] = struct{}{} - } - - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - - deletedRouterID := "peer-75" // router peer - - var affectedRoute *route.Route - for _, r := range account.Routes { - if r.PeerID == deletedRouterID { - affectedRoute = r - break - } - } - require.NotNil(t, affectedRoute, "Router peer should have a route") - - for _, group := range account.Groups { - group.Peers = slices.DeleteFunc(group.Peers, func(id string) bool { - return id == deletedRouterID - }) - } - for routeID, r := range account.Routes { - if r.Peer == account.Peers[deletedRouterID].Key || r.PeerID == deletedRouterID { - delete(account.Routes, routeID) - } - } - delete(account.Peers, deletedRouterID) - delete(validatedPeersMap, deletedRouterID) - - if account.Network != nil { - account.Network.Serial++ - } - - err := builder.OnPeerDeleted(deletedRouterID) + err = builder.OnPeerDeleted(account, deletedRouterID) require.NoError(t, err, "error deleting routing peer from cache") - networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + normalizeAndSortNetworkMap(newNetworkMap) + newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ") + require.NoError(t, err, "error marshaling new network map to JSON") - normalizeAndSortNetworkMap(networkMap) + if string(legacyJSON) != string(newJSON) { + legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_router_peer.json") + newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_deleted_router.json") - jsonData, err := json.MarshalIndent(networkMap, "", " ") - require.NoError(t, err) + err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755) + require.NoError(t, err) - goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_deleted_router.json") + err = os.WriteFile(legacyFilePath, legacyJSON, 0644) + require.NoError(t, err) + t.Logf("Saved legacy network map to %s", legacyFilePath) - t.Log("Update golden file with deleted router...") - err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) - require.NoError(t, err) - err = os.WriteFile(goldenFilePath, jsonData, 0644) - require.NoError(t, err) + err = os.WriteFile(newFilePath, newJSON, 0644) + require.NoError(t, err) + t.Logf("Saved new network map to %s", newFilePath) - expectedJSON, err := os.ReadFile(goldenFilePath) - require.NoError(t, err) - - require.JSONEq(t, string(expectedJSON), string(jsonData), - "network map after deleting router does not match golden file") + require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with deleted router from legacy and new builder do not match") + } } func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) { @@ -847,7 +597,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) { b.Run("old builder after delete", func(b *testing.B) { for i := 0; i < b.N; i++ { for _, testingPeerID := range peerIDs { - _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers()) } } }) @@ -855,7 +605,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) { b.ResetTimer() b.Run("new builder after delete", func(b *testing.B) { for i := 0; i < b.N; i++ { - _ = builder.OnPeerDeleted(deletedPeerID) + _ = builder.OnPeerDeleted(account, deletedPeerID) for _, testingPeerID := range peerIDs { _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) } @@ -927,6 +677,54 @@ func normalizeAndSortNetworkMap(networkMap *types.NetworkMap) { } } +type networkMapJSON struct { + Peers []*nbpeer.Peer `json:"Peers"` + Network *types.Network `json:"Network"` + Routes []*route.Route `json:"Routes"` + DNSConfig dns.Config `json:"DNSConfig"` + OfflinePeers []*nbpeer.Peer `json:"OfflinePeers"` + FirewallRules []*types.FirewallRule `json:"FirewallRules"` + RoutesFirewallRules []*types.RouteFirewallRule `json:"RoutesFirewallRules"` + ForwardingRules []*types.ForwardingRule `json:"ForwardingRules"` + AuthorizedUsers map[string][]string `json:"AuthorizedUsers,omitempty"` + EnableSSH bool `json:"EnableSSH"` +} + +func toNetworkMapJSON(nm *types.NetworkMap) *networkMapJSON { + result := &networkMapJSON{ + Peers: nm.Peers, + Network: nm.Network, + Routes: nm.Routes, + DNSConfig: nm.DNSConfig, + OfflinePeers: nm.OfflinePeers, + FirewallRules: nm.FirewallRules, + RoutesFirewallRules: nm.RoutesFirewallRules, + ForwardingRules: nm.ForwardingRules, + EnableSSH: nm.EnableSSH, + } + + if len(nm.AuthorizedUsers) > 0 { + result.AuthorizedUsers = make(map[string][]string) + localUsers := make([]string, 0, len(nm.AuthorizedUsers)) + for localUser := range nm.AuthorizedUsers { + localUsers = append(localUsers, localUser) + } + sort.Strings(localUsers) + + for _, localUser := range localUsers { + userIDs := nm.AuthorizedUsers[localUser] + sortedUserIDs := make([]string, 0, len(userIDs)) + for userID := range userIDs { + sortedUserIDs = append(sortedUserIDs, userID) + } + sort.Strings(sortedUserIDs) + result.AuthorizedUsers[localUser] = sortedUserIDs + } + } + + return result +} + func createTestAccountWithEntities() *types.Account { peers := make(map[string]*nbpeer.Peer) devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{} @@ -962,9 +760,10 @@ func createTestAccountWithEntities() *types.Account { } groups := map[string]*types.Group{ - allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers}, - devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers}, - opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers}, + allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers}, + devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers}, + opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers}, + sshUsersGroupID: {ID: sshUsersGroupID, Name: "SSH Users", Peers: []string{}}, } policies := []*types.Policy{ @@ -1002,6 +801,15 @@ func createTestAccountWithEntities() *types.Account { Sources: []string{opsGroupID}, DestinationResource: types.Resource{ID: networkResourceID}, }}, }, + { + ID: policyIDSSH, Name: "SSH Access Policy", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: policyIDSSH, Name: "Allow SSH to Ops", Enabled: true, Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolNetbirdSSH, Bidirectional: false, + Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, + AuthorizedGroups: map[string][]string{sshUsersGroupID: {"root", "admin"}}, + }}, + }, } routes := map[route.ID]*route.Route{ @@ -1034,8 +842,15 @@ func createTestAccountWithEntities() *types.Account { }, } + users := map[string]*types.User{ + userAdminID: {Id: userAdminID, Role: types.UserRoleAdmin, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{allGroupID}}, + userDevID: {Id: userDevID, Role: types.UserRoleUser, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{sshUsersGroupID, devGroupID}}, + userOpsID: {Id: userOpsID, Role: types.UserRoleUser, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{sshUsersGroupID, opsGroupID}}, + } + account := &types.Account{ Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes, + Users: users, Network: &types.Network{ Identifier: "net-golden-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1, }, @@ -1071,6 +886,88 @@ func createTestAccountWithEntities() *types.Account { return account } +func TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter_Batched(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + newRouterID := "peer-new-router-102" + newRouterIP := net.IP{100, 64, 1, 2} + newRouter := &nbpeer.Peer{ + ID: newRouterID, + IP: newRouterIP, + Key: fmt.Sprintf("key-%s", newRouterID), + DNSLabel: "newrouter102", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + LastLogin: func() *time.Time { t := time.Now(); return &t }(), + } + + account.Peers[newRouterID] = newRouter + + if opsGroup, exists := account.Groups[opsGroupID]; exists { + opsGroup.Peers = append(opsGroup.Peers, newRouterID) + } + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = append(allGroup.Peers, newRouterID) + } + + newRoute := &route.Route{ + ID: route.ID("route-new-router"), + Network: netip.MustParsePrefix("172.16.0.0/24"), + Peer: newRouter.Key, + PeerID: newRouterID, + Description: "Route from new router", + Enabled: true, + PeerGroups: []string{opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{devGroupID}, + AccountID: account.Id, + } + account.Routes[newRoute.ID] = newRoute + + validatedPeersMap[newRouterID] = struct{}{} + + if account.Network != nil { + account.Network.Serial++ + } + + builder.EnqueuePeersForIncrementalAdd(account, newRouterID) + + time.Sleep(100 * time.Millisecond) + + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json") + + t.Log("Update golden file with OnPeerAdded router...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded router does not match golden file") +} + func createAccountFromFile() (*types.Account, error) { accraw := filepath.Join("testdata", "account_cnlf3j3l0ubs738o5d4g.json") data, err := os.ReadFile(accraw) @@ -1343,7 +1240,7 @@ func BenchmarkGetPeerNetworkMapCompactCached(b *testing.B) { b.Run("Legacy", func(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _ = account.GetPeerNetworkMap(ctx, testingPeerID, customZone, validatedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + _ = account.GetPeerNetworkMap(ctx, testingPeerID, customZone, validatedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers()) } }) b.Run("LegacyCompacted", func(b *testing.B) { diff --git a/management/server/types/networkmapbuilder.go b/management/server/types/networkmapbuilder.go index 877c903b9..93a432193 100644 --- a/management/server/types/networkmapbuilder.go +++ b/management/server/types/networkmapbuilder.go @@ -7,12 +7,12 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "time" log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" + "github.com/netbirdio/netbird/client/ssh/auth" nbdns "github.com/netbirdio/netbird/dns" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" @@ -27,6 +27,9 @@ const ( v6AllWildcard = "::/0" fw = "fw:" rfw = "route-fw:" + + szAddPeerBatch = 10 + maxPeerAddRetries = 20 ) type NetworkMapCache struct { @@ -47,6 +50,10 @@ type NetworkMapCache struct { peerDNS map[string]*nbdns.Config peerFirewallRulesCompact map[string][]*FirewallRule peerRoutesCompact map[string][]*route.Route + peerSSH map[string]*PeerSSHView + + groupIDToUserIDs map[string][]string + allowedUserIDs map[string]struct{} resourceRouters map[string]map[string]*routerTypes.NetworkRouter resourcePolicies map[string][]*Policy @@ -76,41 +83,64 @@ type PeerRoutesView struct { RouteFirewallRuleIDs []string } +type PeerSSHView struct { + EnableSSH bool + AuthorizedUsers map[string]map[string]struct{} +} + type NetworkMapBuilder struct { - account atomic.Pointer[Account] + account *Account cache *NetworkMapCache validatedPeers map[string]struct{} + + apb addPeerBatch +} + +type addPeerBatch struct { + mu sync.Mutex + sg *sync.Cond + ids []string + la *Account + retryCount map[string]int } func NewNetworkMapBuilder(account *Account, validatedPeers map[string]struct{}) *NetworkMapBuilder { builder := &NetworkMapBuilder{ cache: &NetworkMapCache{ - globalRoutes: make(map[route.ID]*route.Route), - globalRules: make(map[string]*FirewallRule), - globalRouteRules: make(map[string]*RouteFirewallRule), - globalPeers: make(map[string]*nbpeer.Peer), - groupToPeers: make(map[string][]string), - peerToGroups: make(map[string][]string), - policyToRules: make(map[string][]*PolicyRule), - groupToPolicies: make(map[string][]*Policy), - groupToRoutes: make(map[string][]*route.Route), - peerToRoutes: make(map[string][]*route.Route), + globalRoutes: make(map[route.ID]*route.Route), + globalRules: make(map[string]*FirewallRule), + globalRouteRules: make(map[string]*RouteFirewallRule), + globalPeers: make(map[string]*nbpeer.Peer), + groupToPeers: make(map[string][]string), + peerToGroups: make(map[string][]string), + policyToRules: make(map[string][]*PolicyRule), + groupToPolicies: make(map[string][]*Policy), + groupToRoutes: make(map[string][]*route.Route), + peerToRoutes: make(map[string][]*route.Route), peerACLs: make(map[string]*PeerACLView), peerRoutes: make(map[string]*PeerRoutesView), peerDNS: make(map[string]*nbdns.Config), + peerSSH: make(map[string]*PeerSSHView), + groupIDToUserIDs: make(map[string][]string), + allowedUserIDs: make(map[string]struct{}), peerFirewallRulesCompact: make(map[string][]*FirewallRule), peerRoutesCompact: make(map[string][]*route.Route), globalResources: make(map[string]*resourceTypes.NetworkResource), - acgToRoutes: make(map[string]map[route.ID]*RouteOwnerInfo), - noACGRoutes: make(map[route.ID]*RouteOwnerInfo), + acgToRoutes: make(map[string]map[route.ID]*RouteOwnerInfo), + noACGRoutes: make(map[route.ID]*RouteOwnerInfo), }, validatedPeers: make(map[string]struct{}), } - builder.account.Store(account) + builder.apb.sg = sync.NewCond(&builder.apb.mu) + builder.apb.ids = make([]string, 0, szAddPeerBatch) + builder.apb.la = account + builder.apb.retryCount = make(map[string]int) + maps.Copy(builder.validatedPeers, validatedPeers) builder.initialBuild(account) + go builder.incAddPeerLoop() return builder } @@ -118,6 +148,8 @@ func (b *NetworkMapBuilder) initialBuild(account *Account) { b.cache.mu.Lock() defer b.cache.mu.Unlock() + b.account = account + start := time.Now() b.buildGlobalIndexes(account) @@ -150,9 +182,15 @@ func (b *NetworkMapBuilder) buildGlobalIndexes(account *Account) { clear(b.cache.peerToRoutes) clear(b.cache.acgToRoutes) clear(b.cache.noACGRoutes) + clear(b.cache.groupIDToUserIDs) + clear(b.cache.allowedUserIDs) + clear(b.cache.peerSSH) maps.Copy(b.cache.globalPeers, account.Peers) + b.cache.groupIDToUserIDs = account.GetActiveGroupUsers() + b.cache.allowedUserIDs = b.buildAllowedUserIDs(account) + for groupID, group := range account.Groups { peersCopy := make([]string, len(group.Peers)) copy(peersCopy, group.Peers) @@ -227,7 +265,7 @@ func (b *NetworkMapBuilder) buildPeerACLView(account *Account, peerID string) { return } - allPotentialPeers, firewallRules := b.getPeerConnectionResources(account, peer, b.validatedPeers) + allPotentialPeers, firewallRules, authorizedUsers, sshEnabled := b.getPeerConnectionResources(account, peer, b.validatedPeers) isRouter, networkResourcesRoutes, sourcePeers := b.getNetworkResourcesForPeer(account, peer) @@ -260,12 +298,17 @@ func (b *NetworkMapBuilder) buildPeerACLView(account *Account, peerID string) { compactedRules := compactFirewallRules(firewallRules) b.cache.peerFirewallRulesCompact[peerID] = compactedRules + b.cache.peerSSH[peerID] = &PeerSSHView{ + EnableSSH: sshEnabled, + AuthorizedUsers: authorizedUsers, + } } func (b *NetworkMapBuilder) getPeerConnectionResources(account *Account, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, -) ([]*nbpeer.Peer, []*FirewallRule) { +) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, bool) { peerID := peer.ID + ctx := context.Background() peerGroups := b.cache.peerToGroups[peerID] peerGroupsMap := make(map[string]struct{}, len(peerGroups)) @@ -278,9 +321,15 @@ func (b *NetworkMapBuilder) getPeerConnectionResources(account *Account, peer *n fwRules := make([]*FirewallRule, 0) peers := make([]*nbpeer.Peer, 0) + authorizedUsers := make(map[string]map[string]struct{}) + sshEnabled := false + for _, group := range peerGroups { policies := b.cache.groupToPolicies[group] for _, policy := range policies { + if isValid := account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID); !isValid { + continue + } rules := b.cache.policyToRules[policy.ID] for _, rule := range rules { var sourcePeers, destinationPeers []*nbpeer.Peer @@ -323,13 +372,13 @@ func (b *NetworkMapBuilder) getPeerConnectionResources(account *Account, peer *n if rule.Bidirectional { if peerInSources { b.generateResourcescached( - account, rule, destinationPeers, FirewallRuleDirectionIN, + rule, destinationPeers, FirewallRuleDirectionIN, peer, &peers, &fwRules, peersExists, rulesExists, ) } if peerInDestinations { b.generateResourcescached( - account, rule, sourcePeers, FirewallRuleDirectionOUT, + rule, sourcePeers, FirewallRuleDirectionOUT, peer, &peers, &fwRules, peersExists, rulesExists, ) } @@ -337,22 +386,58 @@ func (b *NetworkMapBuilder) getPeerConnectionResources(account *Account, peer *n if peerInSources { b.generateResourcescached( - account, rule, destinationPeers, FirewallRuleDirectionOUT, + rule, destinationPeers, FirewallRuleDirectionOUT, peer, &peers, &fwRules, peersExists, rulesExists, ) } if peerInDestinations { b.generateResourcescached( - account, rule, sourcePeers, FirewallRuleDirectionIN, + rule, sourcePeers, FirewallRuleDirectionIN, peer, &peers, &fwRules, peersExists, rulesExists, ) + + if rule.Protocol == PolicyRuleProtocolNetbirdSSH { + sshEnabled = true + switch { + case len(rule.AuthorizedGroups) > 0: + for groupID, localUsers := range rule.AuthorizedGroups { + userIDs, ok := b.cache.groupIDToUserIDs[groupID] + if !ok { + continue + } + + if len(localUsers) == 0 { + localUsers = []string{auth.Wildcard} + } + + for _, localUser := range localUsers { + if authorizedUsers[localUser] == nil { + authorizedUsers[localUser] = make(map[string]struct{}) + } + for _, userID := range userIDs { + authorizedUsers[localUser][userID] = struct{}{} + } + } + } + case rule.AuthorizedUser != "": + if authorizedUsers[auth.Wildcard] == nil { + authorizedUsers[auth.Wildcard] = make(map[string]struct{}) + } + authorizedUsers[auth.Wildcard][rule.AuthorizedUser] = struct{}{} + default: + authorizedUsers[auth.Wildcard] = maps.Clone(b.cache.allowedUserIDs) + } + } else if policyRuleImpliesLegacySSH(rule) && peer.SSHEnabled { + sshEnabled = true + authorizedUsers[auth.Wildcard] = maps.Clone(b.cache.allowedUserIDs) + } } } } } - return peers, fwRules + return peers, fwRules, authorizedUsers, sshEnabled } func (b *NetworkMapBuilder) isPeerInGroupscached(groupIDs []string, peerGroupsMap map[string]struct{}) bool { @@ -405,14 +490,9 @@ func (b *NetworkMapBuilder) getPeersFromGroupscached(account *Account, groupIDs } func (b *NetworkMapBuilder) generateResourcescached( - account *Account, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int, targetPeer *nbpeer.Peer, + rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int, targetPeer *nbpeer.Peer, peers *[]*nbpeer.Peer, rules *[]*FirewallRule, peersExists map[string]struct{}, rulesExists map[string]struct{}, ) { - isAll := false - if allGroup, err := account.GetGroupAll(); err == nil { - isAll = (len(allGroup.Peers) - 1) == len(groupPeers) - } - for _, peer := range groupPeers { if peer == nil { continue @@ -427,11 +507,7 @@ func (b *NetworkMapBuilder) generateResourcescached( PeerIP: peer.IP.String(), Direction: direction, Action: string(rule.Action), - Protocol: string(rule.Protocol), - } - - if isAll { - fr.PeerIP = allPeers + Protocol: firewallRuleProtocol(rule.Protocol), } var s strings.Builder @@ -942,8 +1018,29 @@ func (b *NetworkMapBuilder) getPeerNSGroups(account *Account, peerID string, che return peerNSGroups } -func (b *NetworkMapBuilder) UpdateAccountPointer(account *Account) { - b.account.Store(account) +func (b *NetworkMapBuilder) buildAllowedUserIDs(account *Account) map[string]struct{} { + users := make(map[string]struct{}) + for _, nbUser := range account.Users { + if !nbUser.IsBlocked() && !nbUser.IsServiceUser { + users[nbUser.Id] = struct{}{} + } + } + return users +} + +func firewallRuleProtocol(protocol PolicyRuleProtocolType) string { + if protocol == PolicyRuleProtocolNetbirdSSH { + return string(PolicyRuleProtocolTCP) + } + return string(protocol) +} + +// lock should be held +func (b *NetworkMapBuilder) updateAccountLocked(account *Account) *Account { + if account.Network.CurrentSerial() > b.account.Network.CurrentSerial() { + b.account = account + } + return b.account } func (b *NetworkMapBuilder) GetPeerNetworkMap( @@ -951,25 +1048,27 @@ func (b *NetworkMapBuilder) GetPeerNetworkMap( validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics, ) *NetworkMap { start := time.Now() - account := b.account.Load() + + b.cache.mu.RLock() + defer b.cache.mu.RUnlock() + + account := b.account peer := account.GetPeer(peerID) if peer == nil { return &NetworkMap{Network: account.Network.Copy()} } - b.cache.mu.RLock() - defer b.cache.mu.RUnlock() - aclView := b.cache.peerACLs[peerID] routesView := b.cache.peerRoutes[peerID] dnsConfig := b.cache.peerDNS[peerID] + sshView := b.cache.peerSSH[peerID] if aclView == nil || routesView == nil || dnsConfig == nil { return &NetworkMap{Network: account.Network.Copy()} } - nm := b.assembleNetworkMap(account, peer, aclView, routesView, dnsConfig, peersCustomZone, validatedPeers) + nm := b.assembleNetworkMap(account, peer, aclView, routesView, dnsConfig, sshView, peersCustomZone, validatedPeers) if metrics != nil { objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules)) @@ -987,7 +1086,7 @@ func (b *NetworkMapBuilder) GetPeerNetworkMap( func (b *NetworkMapBuilder) assembleNetworkMap( account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView, - dnsConfig *nbdns.Config, customZone nbdns.CustomZone, validatedPeers map[string]struct{}, + dnsConfig *nbdns.Config, sshView *PeerSSHView, customZone nbdns.CustomZone, validatedPeers map[string]struct{}, ) *NetworkMap { var peersToConnect []*nbpeer.Peer @@ -1045,7 +1144,7 @@ func (b *NetworkMapBuilder) assembleNetworkMap( finalDNSConfig.CustomZones = zones } - return &NetworkMap{ + nm := &NetworkMap{ Peers: peersToConnect, Network: account.Network.Copy(), Routes: routes, @@ -1054,6 +1153,13 @@ func (b *NetworkMapBuilder) assembleNetworkMap( FirewallRules: firewallRules, RoutesFirewallRules: routesFirewallRules, } + + if sshView != nil { + nm.EnableSSH = sshView.EnableSSH + nm.AuthorizedUsers = sshView.AuthorizedUsers + } + + return nm } func (b *NetworkMapBuilder) GetPeerNetworkMapCompactCached( @@ -1061,64 +1167,27 @@ func (b *NetworkMapBuilder) GetPeerNetworkMapCompactCached( validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics, ) *NetworkMap { start := time.Now() - account := b.account.Load() + + b.cache.mu.RLock() + defer b.cache.mu.RUnlock() + + account := b.account peer := account.GetPeer(peerID) if peer == nil { return &NetworkMap{Network: account.Network.Copy()} } - b.cache.mu.RLock() - defer b.cache.mu.RUnlock() - aclView := b.cache.peerACLs[peerID] routesView := b.cache.peerRoutes[peerID] dnsConfig := b.cache.peerDNS[peerID] + sshView := b.cache.peerSSH[peerID] if aclView == nil || routesView == nil || dnsConfig == nil { return &NetworkMap{Network: account.Network.Copy()} } - nm := b.assembleNetworkMapCompactCached(account, peer, aclView, routesView, dnsConfig, peersCustomZone, validatedPeers) - - if metrics != nil { - objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules)) - metrics.CountNetworkMapObjects(objectCount) - metrics.CountGetPeerNetworkMapDuration(time.Since(start)) - - if objectCount > 5000 { - log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects from cache", - account.Id, objectCount) - } - } - - return nm -} - -func (b *NetworkMapBuilder) GetPeerNetworkMapCompact( - ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone, - validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics, -) *NetworkMap { - start := time.Now() - account := b.account.Load() - - peer := account.GetPeer(peerID) - if peer == nil { - return &NetworkMap{Network: account.Network.Copy()} - } - - b.cache.mu.RLock() - defer b.cache.mu.RUnlock() - - aclView := b.cache.peerACLs[peerID] - routesView := b.cache.peerRoutes[peerID] - dnsConfig := b.cache.peerDNS[peerID] - - if aclView == nil || routesView == nil || dnsConfig == nil { - return &NetworkMap{Network: account.Network.Copy()} - } - - nm := b.assembleNetworkMapCompact(account, peer, aclView, routesView, dnsConfig, peersCustomZone, validatedPeers) + nm := b.assembleNetworkMapCompactCached(account, peer, aclView, routesView, dnsConfig, sshView, peersCustomZone, validatedPeers) if metrics != nil { objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules)) @@ -1136,7 +1205,7 @@ func (b *NetworkMapBuilder) GetPeerNetworkMapCompact( func (b *NetworkMapBuilder) assembleNetworkMapCompactCached( account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView, - dnsConfig *nbdns.Config, customZone nbdns.CustomZone, validatedPeers map[string]struct{}, + dnsConfig *nbdns.Config, sshView *PeerSSHView, customZone nbdns.CustomZone, validatedPeers map[string]struct{}, ) *NetworkMap { var peersToConnect []*nbpeer.Peer @@ -1182,7 +1251,7 @@ func (b *NetworkMapBuilder) assembleNetworkMapCompactCached( finalDNSConfig.CustomZones = zones } - return &NetworkMap{ + nm := &NetworkMap{ Peers: peersToConnect, Network: account.Network.Copy(), Routes: routes, @@ -1191,11 +1260,59 @@ func (b *NetworkMapBuilder) assembleNetworkMapCompactCached( FirewallRules: firewallRules, RoutesFirewallRules: routesFirewallRules, } + + if sshView != nil { + nm.EnableSSH = sshView.EnableSSH + nm.AuthorizedUsers = sshView.AuthorizedUsers + } + + return nm +} + +func (b *NetworkMapBuilder) GetPeerNetworkMapCompact( + ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone, + validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics, +) *NetworkMap { + start := time.Now() + + b.cache.mu.RLock() + defer b.cache.mu.RUnlock() + + account := b.account + + peer := account.GetPeer(peerID) + if peer == nil { + return &NetworkMap{Network: account.Network.Copy()} + } + + aclView := b.cache.peerACLs[peerID] + routesView := b.cache.peerRoutes[peerID] + dnsConfig := b.cache.peerDNS[peerID] + sshView := b.cache.peerSSH[peerID] + + if aclView == nil || routesView == nil || dnsConfig == nil { + return &NetworkMap{Network: account.Network.Copy()} + } + + nm := b.assembleNetworkMapCompact(account, peer, aclView, routesView, dnsConfig, sshView, peersCustomZone, validatedPeers) + + if metrics != nil { + objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules)) + metrics.CountNetworkMapObjects(objectCount) + metrics.CountGetPeerNetworkMapDuration(time.Since(start)) + + if objectCount > 5000 { + log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects from cache", + account.Id, objectCount) + } + } + + return nm } func (b *NetworkMapBuilder) assembleNetworkMapCompact( account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView, - dnsConfig *nbdns.Config, customZone nbdns.CustomZone, validatedPeers map[string]struct{}, + dnsConfig *nbdns.Config, sshView *PeerSSHView, customZone nbdns.CustomZone, validatedPeers map[string]struct{}, ) *NetworkMap { var peersToConnect []*nbpeer.Peer @@ -1279,7 +1396,7 @@ func (b *NetworkMapBuilder) assembleNetworkMapCompact( finalDNSConfig.CustomZones = zones } - return &NetworkMap{ + nm := &NetworkMap{ Peers: peersToConnect, Network: account.Network.Copy(), Routes: routes, @@ -1288,14 +1405,13 @@ func (b *NetworkMapBuilder) assembleNetworkMapCompact( FirewallRules: firewallRules, RoutesFirewallRules: routesFirewallRules, } -} -func splitRouteAndPeer(r *route.Route) (string, string) { - parts := strings.Split(string(r.ID), ":") - if len(parts) < 2 { - return string(r.ID), "" + if sshView != nil { + nm.EnableSSH = sshView.EnableSSH + nm.AuthorizedUsers = sshView.AuthorizedUsers } - return parts[0], parts[1] + + return nm } func (b *NetworkMapBuilder) compactRoutesForPeer(peerID string, ownRouteIDs []route.ID, otherRouteIDs []route.ID) []*route.Route { @@ -1427,6 +1543,14 @@ func compactFirewallRules(expandedRules []*FirewallRule) []*FirewallRule { return compactRules } +func splitRouteAndPeer(r *route.Route) (string, string) { + parts := strings.Split(string(r.ID), ":") + if len(parts) < 2 { + return string(r.ID), "" + } + return parts[0], parts[1] +} + func (b *NetworkMapBuilder) generateFirewallRuleID(rule *FirewallRule) string { var s strings.Builder s.WriteString(fw) @@ -1501,6 +1625,106 @@ func (b *NetworkMapBuilder) isPeerRouter(account *Account, peerID string) bool { return false } +func (b *NetworkMapBuilder) incAddPeerLoop() { + for { + b.apb.mu.Lock() + if len(b.apb.ids) == 0 { + b.apb.sg.Wait() + } + b.addPeersIncrementally() + b.apb.mu.Unlock() + } +} + +// lock on b.apb level should be held +func (b *NetworkMapBuilder) addPeersIncrementally() { + peers := slices.Clone(b.apb.ids) + clear(b.apb.ids) + b.apb.ids = b.apb.ids[:0] + latestAcc := b.apb.la + b.apb.mu.Unlock() + + tt := time.Now() + b.cache.mu.Lock() + defer b.cache.mu.Unlock() + + account := b.updateAccountLocked(latestAcc) + + log.Debugf("NetworkMapBuilder: Starting incremental add of %d peers", len(peers)) + + allUpdates := make(map[string]*PeerUpdateDelta) + + for _, peerID := range peers { + peer := account.GetPeer(peerID) + if peer == nil { + b.apb.mu.Lock() + retries := b.apb.retryCount[peerID] + b.apb.mu.Unlock() + + if retries >= maxPeerAddRetries { + log.Errorf("NetworkMapBuilder: peer %s not found in account %s after %d retries, giving up", peerID, account.Id, retries) + b.apb.mu.Lock() + delete(b.apb.retryCount, peerID) + b.apb.mu.Unlock() + continue + } + + log.Warnf("NetworkMapBuilder: peer %s not found in account %s, retry %d/%d", peerID, account.Id, retries+1, maxPeerAddRetries) + b.apb.mu.Lock() + b.apb.retryCount[peerID] = retries + 1 + b.apb.mu.Unlock() + b.enqueuePeersForIncrementalAdd(latestAcc, peerID) + continue + } + + b.apb.mu.Lock() + delete(b.apb.retryCount, peerID) + b.apb.mu.Unlock() + + b.validatedPeers[peerID] = struct{}{} + b.cache.globalPeers[peerID] = peer + + peerGroups := b.updateIndexesForNewPeer(account, peerID) + b.buildPeerACLView(account, peerID) + b.buildPeerRoutesView(account, peerID) + b.buildPeerDNSView(account, peerID) + + peerDeltas := b.collectDeltasForNewPeer(account, peerID, peerGroups) + for affectedPeerID, delta := range peerDeltas { + if existing, ok := allUpdates[affectedPeerID]; ok { + existing.mergeFrom(delta) + continue + } + allUpdates[affectedPeerID] = delta + } + } + + for affectedPeerID, delta := range allUpdates { + b.applyDeltaToPeer(account, affectedPeerID, delta) + } + + log.Debugf("NetworkMapBuilder: Added %d peers to cache, affected %d peers, took %s", len(peers), len(allUpdates), time.Since(tt)) + + b.apb.mu.Lock() + if len(b.apb.ids) > 0 { + b.apb.sg.Signal() + } +} + +func (b *NetworkMapBuilder) enqueuePeersForIncrementalAdd(acc *Account, peerIDs ...string) { + b.apb.mu.Lock() + b.apb.ids = append(b.apb.ids, peerIDs...) + if b.apb.la != nil && acc.Network.CurrentSerial() > b.apb.la.Network.CurrentSerial() { + b.apb.la = acc + } + b.apb.sg.Signal() + b.apb.mu.Unlock() +} + +func (b *NetworkMapBuilder) EnqueuePeersForIncrementalAdd(acc *Account, peerIDs ...string) { + b.enqueuePeersForIncrementalAdd(acc, peerIDs...) +} + type ViewDelta struct { AddedPeerIDs []string RemovedPeerIDs []string @@ -1508,17 +1732,18 @@ type ViewDelta struct { RemovedRuleIDs []string } -func (b *NetworkMapBuilder) OnPeerAddedIncremental(peerID string) error { +func (b *NetworkMapBuilder) OnPeerAddedIncremental(acc *Account, peerID string) error { tt := time.Now() - account := b.account.Load() - peer := account.GetPeer(peerID) + peer := acc.GetPeer(peerID) if peer == nil { - return fmt.Errorf("peer %s not found in account", peerID) + return fmt.Errorf("NetworkMapBuilder: peer %s not found in account", peerID) } b.cache.mu.Lock() defer b.cache.mu.Unlock() + account := b.updateAccountLocked(acc) + log.Debugf("NetworkMapBuilder: Adding peer %s (IP: %s) to cache", peerID, peer.IP.String()) b.validatedPeers[peerID] = struct{}{} @@ -1577,6 +1802,13 @@ func (b *NetworkMapBuilder) updateIndexesForNewPeer(account *Account, peerID str } func (b *NetworkMapBuilder) incrementalUpdateAffectedPeers(account *Account, newPeerID string, peerGroups []string) { + updates := b.collectDeltasForNewPeer(account, newPeerID, peerGroups) + for affectedPeerID, delta := range updates { + b.applyDeltaToPeer(account, affectedPeerID, delta) + } +} + +func (b *NetworkMapBuilder) collectDeltasForNewPeer(account *Account, newPeerID string, peerGroups []string) map[string]*PeerUpdateDelta { updates := b.calculateIncrementalUpdates(account, newPeerID, peerGroups) if b.isPeerRouter(account, newPeerID) { @@ -1596,9 +1828,7 @@ func (b *NetworkMapBuilder) incrementalUpdateAffectedPeers(account *Account, new } } - for affectedPeerID, delta := range updates { - b.applyDeltaToPeer(account, affectedPeerID, delta) - } + return updates } func (b *NetworkMapBuilder) findPeersAffectedByNewRouter(account *Account, newRouterID string, routerGroups []string) map[string]struct{} { @@ -1792,8 +2022,8 @@ func (b *NetworkMapBuilder) calculateNewRouterNetworkResourceUpdates( updates[peerID] = delta } - if delta.AddConnectedPeer == "" { - delta.AddConnectedPeer = newPeerID + if !slices.Contains(delta.AddConnectedPeers, newPeerID) { + delta.AddConnectedPeers = append(delta.AddConnectedPeers, newPeerID) } delta.RebuildRoutesView = true @@ -1922,8 +2152,8 @@ func (b *NetworkMapBuilder) calculateNetworkResourceFirewallUpdates( updates[routerPeerID] = delta } - if delta.AddConnectedPeer == "" { - delta.AddConnectedPeer = newPeerID + if !slices.Contains(delta.AddConnectedPeers, newPeerID) { + delta.AddConnectedPeers = append(delta.AddConnectedPeers, newPeerID) } delta.RebuildRoutesView = true @@ -1933,13 +2163,63 @@ func (b *NetworkMapBuilder) calculateNetworkResourceFirewallUpdates( type PeerUpdateDelta struct { PeerID string - AddConnectedPeer string + AddConnectedPeers []string AddFirewallRules []*FirewallRuleDelta AddRoutes []route.ID UpdateRouteFirewallRules []*RouteFirewallRuleUpdate UpdateDNS bool RebuildRoutesView bool } + +func (d *PeerUpdateDelta) mergeFrom(other *PeerUpdateDelta) { + for _, peerID := range other.AddConnectedPeers { + if !slices.Contains(d.AddConnectedPeers, peerID) { + d.AddConnectedPeers = append(d.AddConnectedPeers, peerID) + } + } + + existingRuleIDs := make(map[string]struct{}, len(d.AddFirewallRules)) + for _, rule := range d.AddFirewallRules { + existingRuleIDs[rule.RuleID] = struct{}{} + } + for _, rule := range other.AddFirewallRules { + if _, exists := existingRuleIDs[rule.RuleID]; !exists { + d.AddFirewallRules = append(d.AddFirewallRules, rule) + existingRuleIDs[rule.RuleID] = struct{}{} + } + } + + for _, routeID := range other.AddRoutes { + if !slices.Contains(d.AddRoutes, routeID) { + d.AddRoutes = append(d.AddRoutes, routeID) + } + } + + existingRouteUpdates := make(map[string]map[string]struct{}) + for _, update := range d.UpdateRouteFirewallRules { + if existingRouteUpdates[update.RuleID] == nil { + existingRouteUpdates[update.RuleID] = make(map[string]struct{}) + } + existingRouteUpdates[update.RuleID][update.AddSourceIP] = struct{}{} + } + for _, update := range other.UpdateRouteFirewallRules { + if existingRouteUpdates[update.RuleID] == nil { + existingRouteUpdates[update.RuleID] = make(map[string]struct{}) + } + if _, exists := existingRouteUpdates[update.RuleID][update.AddSourceIP]; !exists { + d.UpdateRouteFirewallRules = append(d.UpdateRouteFirewallRules, update) + existingRouteUpdates[update.RuleID][update.AddSourceIP] = struct{}{} + } + } + + if other.UpdateDNS { + d.UpdateDNS = true + } + if other.RebuildRoutesView { + d.RebuildRoutesView = true + } +} + type FirewallRuleDelta struct { Rule *FirewallRule RuleID string @@ -1977,7 +2257,7 @@ func (b *NetworkMapBuilder) addUpdateForPeersInGroups( PeerIP: newPeer.IP.String(), Direction: direction, Action: string(rule.Action), - Protocol: string(rule.Protocol), + Protocol: firewallRuleProtocol(rule.Protocol), } for _, peerID := range peers { if peerID == newPeerID { @@ -2028,7 +2308,7 @@ func (b *NetworkMapBuilder) addUpdateForDirectPeerResource( PeerIP: newPeer.IP.String(), Direction: direction, Action: string(rule.Action), - Protocol: string(rule.Protocol), + Protocol: firewallRuleProtocol(rule.Protocol), } b.addOrUpdateFirewallRuleInDelta(updates, targetPeerID, newPeerID, rule, direction, fr, fr.PeerIP, targetPeer) @@ -2041,11 +2321,13 @@ func (b *NetworkMapBuilder) addOrUpdateFirewallRuleInDelta( delta := updates[targetPeerID] if delta == nil { delta = &PeerUpdateDelta{ - PeerID: targetPeerID, - AddConnectedPeer: newPeerID, - AddFirewallRules: make([]*FirewallRuleDelta, 0), + PeerID: targetPeerID, + AddConnectedPeers: []string{newPeerID}, + AddFirewallRules: make([]*FirewallRuleDelta, 0), } updates[targetPeerID] = delta + } else if !slices.Contains(delta.AddConnectedPeers, newPeerID) { + delta.AddConnectedPeers = append(delta.AddConnectedPeers, newPeerID) } baseRule.PeerIP = peerIP @@ -2071,10 +2353,12 @@ func (b *NetworkMapBuilder) addOrUpdateFirewallRuleInDelta( } func (b *NetworkMapBuilder) applyDeltaToPeer(account *Account, peerID string, delta *PeerUpdateDelta) { - if delta.AddConnectedPeer != "" || len(delta.AddFirewallRules) > 0 { + if len(delta.AddConnectedPeers) > 0 || len(delta.AddFirewallRules) > 0 { if aclView := b.cache.peerACLs[peerID]; aclView != nil { - if delta.AddConnectedPeer != "" && !slices.Contains(aclView.ConnectedPeerIDs, delta.AddConnectedPeer) { - aclView.ConnectedPeerIDs = append(aclView.ConnectedPeerIDs, delta.AddConnectedPeer) + for _, connectedPeerID := range delta.AddConnectedPeers { + if !slices.Contains(aclView.ConnectedPeerIDs, connectedPeerID) { + aclView.ConnectedPeerIDs = append(aclView.ConnectedPeerIDs, connectedPeerID) + } } for _, ruleDelta := range delta.AddFirewallRules { @@ -2130,11 +2414,11 @@ func (b *NetworkMapBuilder) updateRouteFirewallRules(routesView *PeerRoutesView, } } -func (b *NetworkMapBuilder) OnPeerDeleted(peerID string) error { +func (b *NetworkMapBuilder) OnPeerDeleted(acc *Account, peerID string) error { b.cache.mu.Lock() defer b.cache.mu.Unlock() - account := b.account.Load() + account := b.updateAccountLocked(acc) deletedPeer := b.cache.globalPeers[peerID] if deletedPeer == nil { @@ -2190,6 +2474,7 @@ func (b *NetworkMapBuilder) OnPeerDeleted(peerID string) error { delete(b.cache.peerACLs, peerID) delete(b.cache.peerRoutes, peerID) delete(b.cache.peerDNS, peerID) + delete(b.cache.peerSSH, peerID) delete(b.cache.globalPeers, peerID) @@ -2240,11 +2525,16 @@ func (b *NetworkMapBuilder) OnPeerDeleted(peerID string) error { b.buildPeerRoutesView(account, affectedPeerID) } - peerDeletionUpdates := b.findPeersAffectedByDeletedPeerACL(peerID, peerIP) + peersToRebuildACL := make(map[string]struct{}) + peerDeletionUpdates := b.findPeersAffectedByDeletedPeerACL(peerID, peerIP, peerGroups, peersToRebuildACL) for affectedPeerID, updates := range peerDeletionUpdates { b.applyDeletionUpdates(affectedPeerID, updates) } + for affectedPeerID := range peersToRebuildACL { + b.buildPeerACLView(account, affectedPeerID) + } + b.cleanupUnusedRules() log.Debugf("NetworkMapBuilder: Deleted peer %s, affected %d other peers", peerID, len(affectedPeers)) @@ -2255,6 +2545,8 @@ func (b *NetworkMapBuilder) OnPeerDeleted(peerID string) error { func (b *NetworkMapBuilder) findPeersAffectedByDeletedPeerACL( deletedPeerID string, peerIP string, + peerGroups []string, + peersToRebuildACL map[string]struct{}, ) map[string]*PeerDeletionUpdate { affected := make(map[string]*PeerDeletionUpdate) @@ -2264,26 +2556,47 @@ func (b *NetworkMapBuilder) findPeersAffectedByDeletedPeerACL( continue } - if !slices.Contains(aclView.ConnectedPeerIDs, deletedPeerID) { - continue - } - if affected[peerID] == nil { - affected[peerID] = &PeerDeletionUpdate{ - RemovePeerID: deletedPeerID, - PeerIP: peerIP, + if slices.Contains(aclView.ConnectedPeerIDs, deletedPeerID) { + peersToRebuildACL[peerID] = struct{}{} + if affected[peerID] == nil { + affected[peerID] = &PeerDeletionUpdate{ + RemovePeerID: deletedPeerID, + PeerIP: peerIP, + } } } + } - for _, ruleID := range aclView.FirewallRuleIDs { - if rule := b.cache.globalRules[ruleID]; rule != nil && rule.PeerIP == peerIP { - affected[peerID].RemoveFirewallRuleIDs = append( - affected[peerID].RemoveFirewallRuleIDs, - ruleID, - ) + affectedRouteOwners := make(map[string]struct{}) + + for _, groupID := range peerGroups { + if routeMap, ok := b.cache.acgToRoutes[groupID]; ok { + for _, info := range routeMap { + if info.PeerID != deletedPeerID { + affectedRouteOwners[info.PeerID] = struct{}{} + } } } } + for _, info := range b.cache.noACGRoutes { + if info.PeerID != deletedPeerID { + affectedRouteOwners[info.PeerID] = struct{}{} + } + } + + for ownerPeerID := range affectedRouteOwners { + if affected[ownerPeerID] == nil { + affected[ownerPeerID] = &PeerDeletionUpdate{ + RemovePeerID: deletedPeerID, + PeerIP: peerIP, + RemoveFromSourceRanges: true, + } + } else { + affected[ownerPeerID].RemoveFromSourceRanges = true + } + } + return affected } @@ -2296,18 +2609,6 @@ type PeerDeletionUpdate struct { } func (b *NetworkMapBuilder) applyDeletionUpdates(peerID string, updates *PeerDeletionUpdate) { - if aclView := b.cache.peerACLs[peerID]; aclView != nil { - aclView.ConnectedPeerIDs = slices.DeleteFunc(aclView.ConnectedPeerIDs, func(id string) bool { - return id == updates.RemovePeerID - }) - - if len(updates.RemoveFirewallRuleIDs) > 0 { - aclView.FirewallRuleIDs = slices.DeleteFunc(aclView.FirewallRuleIDs, func(ruleID string) bool { - return slices.Contains(updates.RemoveFirewallRuleIDs, ruleID) - }) - } - } - if routesView := b.cache.peerRoutes[peerID]; routesView != nil { if len(updates.RemoveRouteIDs) > 0 { routesView.NetworkResourceIDs = slices.DeleteFunc(routesView.NetworkResourceIDs, func(routeID route.ID) bool { diff --git a/management/server/types/policy.go b/management/server/types/policy.go index 5e86a87c6..d4e1a8816 100644 --- a/management/server/types/policy.go +++ b/management/server/types/policy.go @@ -23,6 +23,8 @@ const ( PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp") // PolicyRuleProtocolICMP type of traffic PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp") + // PolicyRuleProtocolNetbirdSSH type of traffic + PolicyRuleProtocolNetbirdSSH = PolicyRuleProtocolType("netbird-ssh") ) const ( @@ -167,6 +169,8 @@ func ParseRuleString(rule string) (PolicyRuleProtocolType, RulePortRange, error) protocol = PolicyRuleProtocolUDP case "icmp": return "", RulePortRange{}, errors.New("icmp does not accept ports; use 'icmp' without '/…'") + case "netbird-ssh": + return PolicyRuleProtocolNetbirdSSH, RulePortRange{Start: nativeSSHPortNumber, End: nativeSSHPortNumber}, nil default: return "", RulePortRange{}, fmt.Errorf("invalid protocol: %q", protoStr) } diff --git a/management/server/types/policyrule.go b/management/server/types/policyrule.go index 2643ae45c..bb75dd555 100644 --- a/management/server/types/policyrule.go +++ b/management/server/types/policyrule.go @@ -80,6 +80,12 @@ type PolicyRule struct { // PortRanges a list of port ranges. PortRanges []RulePortRange `gorm:"serializer:json"` + + // AuthorizedGroups is a map of groupIDs and their respective access to local users via ssh + AuthorizedGroups map[string][]string `gorm:"serializer:json"` + + // AuthorizedUser is a list of userIDs that are authorized to access local resources via ssh + AuthorizedUser string } // Copy returns a copy of a policy rule @@ -99,10 +105,16 @@ func (pm *PolicyRule) Copy() *PolicyRule { Protocol: pm.Protocol, Ports: make([]string, len(pm.Ports)), PortRanges: make([]RulePortRange, len(pm.PortRanges)), + AuthorizedGroups: make(map[string][]string, len(pm.AuthorizedGroups)), + AuthorizedUser: pm.AuthorizedUser, } copy(rule.Destinations, pm.Destinations) copy(rule.Sources, pm.Sources) copy(rule.Ports, pm.Ports) copy(rule.PortRanges, pm.PortRanges) + for k, v := range pm.AuthorizedGroups { + rule.AuthorizedGroups[k] = make([]string, len(v)) + copy(rule.AuthorizedGroups[k], v) + } return rule } diff --git a/management/server/types/settings.go b/management/server/types/settings.go index b4afb2f5e..867e12bef 100644 --- a/management/server/types/settings.go +++ b/management/server/types/settings.go @@ -52,6 +52,9 @@ type Settings struct { // LazyConnectionEnabled indicates if the experimental feature is enabled or disabled LazyConnectionEnabled bool `gorm:"default:false"` + + // AutoUpdateVersion client auto-update version + AutoUpdateVersion string `gorm:"default:'disabled'"` } // Copy copies the Settings struct @@ -72,6 +75,7 @@ func (s *Settings) Copy() *Settings { LazyConnectionEnabled: s.LazyConnectionEnabled, DNSDomain: s.DNSDomain, NetworkRange: s.NetworkRange, + AutoUpdateVersion: s.AutoUpdateVersion, } if s.Extra != nil { settings.Extra = s.Extra.Copy() diff --git a/management/server/types/user.go b/management/server/types/user.go index beb3586df..dc601e15b 100644 --- a/management/server/types/user.go +++ b/management/server/types/user.go @@ -7,6 +7,7 @@ import ( "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integration_reference" + "github.com/netbirdio/netbird/util/crypt" ) const ( @@ -65,7 +66,11 @@ type UserInfo struct { LastLogin time.Time `json:"last_login"` Issued string `json:"issued"` PendingApproval bool `json:"pending_approval"` + Password string `json:"password"` IntegrationReference integration_reference.IntegrationReference `json:"-"` + // IdPID is the identity provider ID (connector ID) extracted from the Dex-encoded user ID. + // This field is only populated when the user ID can be decoded from Dex's format. + IdPID string `json:"idp_id,omitempty"` } // User represents a user of the system @@ -96,6 +101,9 @@ type User struct { Issued string `gorm:"default:api"` IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` + + Name string `gorm:"default:''"` + Email string `gorm:"default:''"` } // IsBlocked returns true if the user is blocked, false otherwise @@ -143,10 +151,16 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { } if userData == nil { + + name := u.Name + if u.IsServiceUser { + name = u.ServiceUserName + } + return &UserInfo{ ID: u.Id, - Email: "", - Name: u.ServiceUserName, + Email: u.Email, + Name: name, Role: string(u.Role), AutoGroups: u.AutoGroups, Status: string(UserStatusActive), @@ -178,6 +192,7 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { LastLogin: u.GetLastLogin(), Issued: u.Issued, PendingApproval: u.PendingApproval, + Password: userData.Password, }, nil } @@ -204,11 +219,13 @@ func (u *User) Copy() *User { CreatedAt: u.CreatedAt, Issued: u.Issued, IntegrationReference: u.IntegrationReference, + Email: u.Email, + Name: u.Name, } } // NewUser creates a new user -func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User { +func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string, email string, name string) *User { return &User{ Id: id, Role: role, @@ -218,20 +235,70 @@ func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, se AutoGroups: autoGroups, Issued: issued, CreatedAt: time.Now().UTC(), + Name: name, + Email: email, } } // NewRegularUser creates a new user with role UserRoleUser -func NewRegularUser(id string) *User { - return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI) +func NewRegularUser(id, email, name string) *User { + return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI, email, name) } // NewAdminUser creates a new user with role UserRoleAdmin func NewAdminUser(id string) *User { - return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI) + return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI, "", "") } // NewOwnerUser creates a new user with role UserRoleOwner -func NewOwnerUser(id string) *User { - return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI) +func NewOwnerUser(id string, email string, name string) *User { + return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI, email, name) +} + +// EncryptSensitiveData encrypts the user's sensitive fields (Email and Name) in place. +func (u *User) EncryptSensitiveData(enc *crypt.FieldEncrypt) error { + if enc == nil { + return nil + } + + var err error + if u.Email != "" { + u.Email, err = enc.Encrypt(u.Email) + if err != nil { + return fmt.Errorf("encrypt email: %w", err) + } + } + + if u.Name != "" { + u.Name, err = enc.Encrypt(u.Name) + if err != nil { + return fmt.Errorf("encrypt name: %w", err) + } + } + + return nil +} + +// DecryptSensitiveData decrypts the user's sensitive fields (Email and Name) in place. +func (u *User) DecryptSensitiveData(enc *crypt.FieldEncrypt) error { + if enc == nil { + return nil + } + + var err error + if u.Email != "" { + u.Email, err = enc.Decrypt(u.Email) + if err != nil { + return fmt.Errorf("decrypt email: %w", err) + } + } + + if u.Name != "" { + u.Name, err = enc.Decrypt(u.Name) + if err != nil { + return fmt.Errorf("decrypt name: %w", err) + } + } + + return nil } diff --git a/management/server/types/user_test.go b/management/server/types/user_test.go new file mode 100644 index 000000000..e11df96aa --- /dev/null +++ b/management/server/types/user_test.go @@ -0,0 +1,298 @@ +package types + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/util/crypt" +) + +func TestUser_EncryptSensitiveData(t *testing.T) { + key, err := crypt.GenerateKey() + require.NoError(t, err) + + fieldEncrypt, err := crypt.NewFieldEncrypt(key) + require.NoError(t, err) + + t.Run("encrypt email and name", func(t *testing.T) { + user := &User{ + Id: "user-1", + Email: "test@example.com", + Name: "Test User", + } + + err := user.EncryptSensitiveData(fieldEncrypt) + require.NoError(t, err) + + assert.NotEqual(t, "test@example.com", user.Email, "email should be encrypted") + assert.NotEqual(t, "Test User", user.Name, "name should be encrypted") + assert.NotEmpty(t, user.Email, "encrypted email should not be empty") + assert.NotEmpty(t, user.Name, "encrypted name should not be empty") + }) + + t.Run("encrypt empty email and name", func(t *testing.T) { + user := &User{ + Id: "user-2", + Email: "", + Name: "", + } + + err := user.EncryptSensitiveData(fieldEncrypt) + require.NoError(t, err) + + assert.Equal(t, "", user.Email, "empty email should remain empty") + assert.Equal(t, "", user.Name, "empty name should remain empty") + }) + + t.Run("encrypt only email", func(t *testing.T) { + user := &User{ + Id: "user-3", + Email: "test@example.com", + Name: "", + } + + err := user.EncryptSensitiveData(fieldEncrypt) + require.NoError(t, err) + + assert.NotEqual(t, "test@example.com", user.Email, "email should be encrypted") + assert.NotEmpty(t, user.Email, "encrypted email should not be empty") + assert.Equal(t, "", user.Name, "empty name should remain empty") + }) + + t.Run("encrypt only name", func(t *testing.T) { + user := &User{ + Id: "user-4", + Email: "", + Name: "Test User", + } + + err := user.EncryptSensitiveData(fieldEncrypt) + require.NoError(t, err) + + assert.Equal(t, "", user.Email, "empty email should remain empty") + assert.NotEqual(t, "Test User", user.Name, "name should be encrypted") + assert.NotEmpty(t, user.Name, "encrypted name should not be empty") + }) + + t.Run("nil encryptor returns no error", func(t *testing.T) { + user := &User{ + Id: "user-5", + Email: "test@example.com", + Name: "Test User", + } + + err := user.EncryptSensitiveData(nil) + require.NoError(t, err) + + assert.Equal(t, "test@example.com", user.Email, "email should remain unchanged with nil encryptor") + assert.Equal(t, "Test User", user.Name, "name should remain unchanged with nil encryptor") + }) +} + +func TestUser_DecryptSensitiveData(t *testing.T) { + key, err := crypt.GenerateKey() + require.NoError(t, err) + + fieldEncrypt, err := crypt.NewFieldEncrypt(key) + require.NoError(t, err) + + t.Run("decrypt email and name", func(t *testing.T) { + originalEmail := "test@example.com" + originalName := "Test User" + + user := &User{ + Id: "user-1", + Email: originalEmail, + Name: originalName, + } + + err := user.EncryptSensitiveData(fieldEncrypt) + require.NoError(t, err) + + err = user.DecryptSensitiveData(fieldEncrypt) + require.NoError(t, err) + + assert.Equal(t, originalEmail, user.Email, "decrypted email should match original") + assert.Equal(t, originalName, user.Name, "decrypted name should match original") + }) + + t.Run("decrypt empty email and name", func(t *testing.T) { + user := &User{ + Id: "user-2", + Email: "", + Name: "", + } + + err := user.DecryptSensitiveData(fieldEncrypt) + require.NoError(t, err) + + assert.Equal(t, "", user.Email, "empty email should remain empty") + assert.Equal(t, "", user.Name, "empty name should remain empty") + }) + + t.Run("decrypt only email", func(t *testing.T) { + originalEmail := "test@example.com" + + user := &User{ + Id: "user-3", + Email: originalEmail, + Name: "", + } + + err := user.EncryptSensitiveData(fieldEncrypt) + require.NoError(t, err) + + err = user.DecryptSensitiveData(fieldEncrypt) + require.NoError(t, err) + + assert.Equal(t, originalEmail, user.Email, "decrypted email should match original") + assert.Equal(t, "", user.Name, "empty name should remain empty") + }) + + t.Run("decrypt only name", func(t *testing.T) { + originalName := "Test User" + + user := &User{ + Id: "user-4", + Email: "", + Name: originalName, + } + + err := user.EncryptSensitiveData(fieldEncrypt) + require.NoError(t, err) + + err = user.DecryptSensitiveData(fieldEncrypt) + require.NoError(t, err) + + assert.Equal(t, "", user.Email, "empty email should remain empty") + assert.Equal(t, originalName, user.Name, "decrypted name should match original") + }) + + t.Run("nil encryptor returns no error", func(t *testing.T) { + user := &User{ + Id: "user-5", + Email: "test@example.com", + Name: "Test User", + } + + err := user.DecryptSensitiveData(nil) + require.NoError(t, err) + + assert.Equal(t, "test@example.com", user.Email, "email should remain unchanged with nil encryptor") + assert.Equal(t, "Test User", user.Name, "name should remain unchanged with nil encryptor") + }) + + t.Run("decrypt with invalid ciphertext returns error", func(t *testing.T) { + user := &User{ + Id: "user-6", + Email: "not-valid-base64-ciphertext!!!", + Name: "Test User", + } + + err := user.DecryptSensitiveData(fieldEncrypt) + require.Error(t, err) + assert.Contains(t, err.Error(), "decrypt email") + }) + + t.Run("decrypt with wrong key returns error", func(t *testing.T) { + originalEmail := "test@example.com" + originalName := "Test User" + + user := &User{ + Id: "user-7", + Email: originalEmail, + Name: originalName, + } + + err := user.EncryptSensitiveData(fieldEncrypt) + require.NoError(t, err) + + differentKey, err := crypt.GenerateKey() + require.NoError(t, err) + + differentEncrypt, err := crypt.NewFieldEncrypt(differentKey) + require.NoError(t, err) + + err = user.DecryptSensitiveData(differentEncrypt) + require.Error(t, err) + assert.Contains(t, err.Error(), "decrypt email") + }) +} + +func TestUser_EncryptDecryptRoundTrip(t *testing.T) { + key, err := crypt.GenerateKey() + require.NoError(t, err) + + fieldEncrypt, err := crypt.NewFieldEncrypt(key) + require.NoError(t, err) + + testCases := []struct { + name string + email string + uname string + }{ + { + name: "standard email and name", + email: "user@example.com", + uname: "John Doe", + }, + { + name: "email with special characters", + email: "user+tag@sub.example.com", + uname: "O'Brien, Mary-Jane", + }, + { + name: "unicode characters", + email: "user@example.com", + uname: "Jean-Pierre Müller 日本語", + }, + { + name: "long values", + email: "very.long.email.address.that.is.quite.extended@subdomain.example.organization.com", + uname: "A Very Long Name That Contains Many Words And Is Quite Extended For Testing Purposes", + }, + { + name: "empty email only", + email: "", + uname: "Name Only", + }, + { + name: "empty name only", + email: "email@only.com", + uname: "", + }, + { + name: "both empty", + email: "", + uname: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + user := &User{ + Id: "test-user", + Email: tc.email, + Name: tc.uname, + } + + err := user.EncryptSensitiveData(fieldEncrypt) + require.NoError(t, err) + + if tc.email != "" { + assert.NotEqual(t, tc.email, user.Email, "email should be encrypted") + } + if tc.uname != "" { + assert.NotEqual(t, tc.uname, user.Name, "name should be encrypted") + } + + err = user.DecryptSensitiveData(fieldEncrypt) + require.NoError(t, err) + + assert.Equal(t, tc.email, user.Email, "decrypted email should match original") + assert.Equal(t, tc.uname, user.Name, "decrypted name should match original") + }) + } +} diff --git a/management/server/user.go b/management/server/user.go index ca02f91e6..4f9007b61 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -13,6 +13,7 @@ import ( "github.com/google/uuid" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/idp/dex" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -40,7 +41,7 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI } newUserID := uuid.New().String() - newUser := types.NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, types.UserIssuedAPI) + newUser := types.NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, types.UserIssuedAPI, "", "") newUser.AccountID = accountID log.WithContext(ctx).Debugf("New User: %v", newUser) @@ -104,7 +105,12 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u inviterID = createdBy } - idpUser, err := am.createNewIdpUser(ctx, accountID, inviterID, invite) + var idpUser *idp.UserData + if IsEmbeddedIdp(am.idpManager) { + idpUser, err = am.createEmbeddedIdpUser(ctx, accountID, inviterID, invite) + } else { + idpUser, err = am.createNewIdpUser(ctx, accountID, inviterID, invite) + } if err != nil { return nil, err } @@ -117,18 +123,26 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u Issued: invite.Issued, IntegrationReference: invite.IntegrationReference, CreatedAt: time.Now().UTC(), + Email: invite.Email, + Name: invite.Name, } if err = am.Store.SaveUser(ctx, newUser); err != nil { return nil, err } - _, err = am.refreshCache(ctx, accountID) - if err != nil { - return nil, err + if !IsEmbeddedIdp(am.idpManager) { + _, err = am.refreshCache(ctx, accountID) + if err != nil { + return nil, err + } } - am.StoreEvent(ctx, userID, newUser.Id, accountID, activity.UserInvited, nil) + eventType := activity.UserInvited + if IsEmbeddedIdp(am.idpManager) { + eventType = activity.UserCreated + } + am.StoreEvent(ctx, userID, newUser.Id, accountID, eventType, nil) return newUser.ToUserInfo(idpUser) } @@ -172,6 +186,34 @@ func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID return am.idpManager.CreateUser(ctx, invite.Email, invite.Name, accountID, inviterUser.Email) } +// createEmbeddedIdpUser validates the invite and creates a new user in the embedded IdP. +// Unlike createNewIdpUser, this method fetches user data directly from the database +// since the embedded IdP usage ensures the username and email are stored locally in the User table. +func (am *DefaultAccountManager) createEmbeddedIdpUser(ctx context.Context, accountID string, inviterID string, invite *types.UserInfo) (*idp.UserData, error) { + inviter, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, inviterID) + if err != nil { + return nil, fmt.Errorf("failed to get inviter user: %w", err) + } + + if inviter == nil { + return nil, status.Errorf(status.NotFound, "inviter user with ID %s doesn't exist", inviterID) + } + + // check if the user is already registered with this email => reject + existingUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return nil, err + } + + for _, user := range existingUsers { + if strings.EqualFold(user.Email, invite.Email) { + return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account") + } + } + + return am.idpManager.CreateUser(ctx, invite.Email, invite.Name, accountID, inviter.Email) +} + func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) { return am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id) } @@ -523,16 +565,14 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate( + _, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate( ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings, ) if err != nil { return fmt.Errorf("failed to process update for user %s: %w", update.Id, err) } - if userHadPeers { - updateAccountPeers = true - } + updateAccountPeers = true err = transaction.SaveUser(ctx, updatedUser) if err != nil { @@ -579,9 +619,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, log.WithContext(ctx).Errorf("failed update expired peers: %s", err) return nil, err } - } - - if settings.GroupsPropagationEnabled && updateAccountPeers { + } else if updateAccountPeers { if err = am.Store.IncrementNetworkSerial(ctx, accountID); err != nil { return nil, fmt.Errorf("failed to increment network serial: %w", err) } @@ -761,7 +799,7 @@ func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initi // If the AccountManager has a non-nil idpManager and the User is not a service user, // it will attempt to look up the UserData from the cache. func (am *DefaultAccountManager) getUserInfo(ctx context.Context, user *types.User, accountID string) (*types.UserInfo, error) { - if !isNil(am.idpManager) && !user.IsServiceUser { + if !isNil(am.idpManager) && !user.IsServiceUser && !IsEmbeddedIdp(am.idpManager) { userData, err := am.lookupUserInCache(ctx, user.Id, accountID) if err != nil { return nil, err @@ -812,7 +850,10 @@ func validateUserUpdate(groupsMap map[string]*types.Group, initiatorUser, oldUse } // GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist -func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, userID, domain string) (*types.Account, error) { +func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, userAuth auth.UserAuth) (*types.Account, error) { + userID := userAuth.UserId + domain := userAuth.Domain + start := time.Now() unlock := am.Store.AcquireGlobalLock(ctx) defer unlock() @@ -823,7 +864,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u account, err := am.Store.GetAccountByUser(ctx, userID) if err != nil { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { - account, err = am.newAccount(ctx, userID, lowerDomain) + account, err = am.newAccount(ctx, userID, lowerDomain, userAuth.Email, userAuth.Name) if err != nil { return nil, err } @@ -888,7 +929,8 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a var queriedUsers []*idp.UserData var err error - if !isNil(am.idpManager) { + // embedded IdP ensures that we have user data (email and name) stored in the database. + if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) { users := make(map[string]userLoggedInOnce, len(accountUsers)) usersFromIntegration := make([]*idp.UserData, 0) for _, user := range accountUsers { @@ -925,6 +967,10 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a if err != nil { return nil, err } + // Try to decode Dex user ID to extract the IdP ID (connector ID) + if _, connectorID, decodeErr := dex.DecodeDexUserID(accountUser.Id); decodeErr == nil && connectorID != "" { + info.IdPID = connectorID + } userInfosMap[accountUser.Id] = info } @@ -946,7 +992,7 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a info = &types.UserInfo{ ID: localUser.Id, - Email: "", + Email: localUser.Email, Name: name, Role: string(localUser.Role), AutoGroups: localUser.AutoGroups, @@ -955,6 +1001,10 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a NonDeletable: localUser.NonDeletable, } } + // Try to decode Dex user ID to extract the IdP ID (connector ID) + if _, connectorID, decodeErr := dex.DecodeDexUserID(localUser.Id); decodeErr == nil && connectorID != "" { + info.IdPID = connectorID + } userInfosMap[info.ID] = info } @@ -996,6 +1046,12 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou ) } + if len(peerIDs) != 0 { + if err := am.Store.IncrementNetworkSerial(ctx, accountID); err != nil { + return err + } + } + err = am.networkMapController.OnPeersUpdated(ctx, accountID, peerIDs) if err != nil { return fmt.Errorf("notify network map controller of peer update: %w", err) @@ -1111,6 +1167,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI var updateAccountPeers bool var userPeers []*nbpeer.Peer var targetUser *types.User + var settings *types.Settings var err error err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { @@ -1119,6 +1176,11 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI return fmt.Errorf("failed to get user to delete: %w", err) } + settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return fmt.Errorf("failed to get account settings: %w", err) + } + userPeers, err = transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, targetUserInfo.ID) if err != nil { return fmt.Errorf("failed to get user peers: %w", err) @@ -1126,7 +1188,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI if len(userPeers) > 0 { updateAccountPeers = true - addPeerRemovedEvents, err = deletePeers(ctx, am, transaction, accountID, targetUserInfo.ID, userPeers) + addPeerRemovedEvents, err = deletePeers(ctx, am, transaction, accountID, targetUserInfo.ID, userPeers, settings) if err != nil { return fmt.Errorf("failed to delete user peers: %w", err) } @@ -1145,6 +1207,9 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI var peerIDs []string for _, peer := range userPeers { peerIDs = append(peerIDs, peer.ID) + if err = am.integratedPeerValidator.PeerDeleted(ctx, accountID, peer.ID, settings.Extra); err != nil { + log.WithContext(ctx).Errorf("failed to delete peer %s from integrated validator: %v", peer.ID, err) + } } if err := am.networkMapController.OnPeersDeleted(ctx, accountID, peerIDs); err != nil { log.WithContext(ctx).Errorf("failed to delete peers %s from network map: %v", peerIDs, err) @@ -1153,6 +1218,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI for _, addPeerRemovedEvent := range addPeerRemovedEvents { addPeerRemovedEvent() } + meta := map[string]any{"name": targetUserInfo.Name, "email": targetUserInfo.Email, "created_at": targetUser.CreatedAt} am.StoreEvent(ctx, initiatorUserID, targetUser.Id, accountID, activity.UserDeleted, meta) diff --git a/management/server/user_test.go b/management/server/user_test.go index 0d778cfa2..6d356a8b1 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -3,6 +3,7 @@ package server import ( "context" "fmt" + "os" "reflect" "testing" "time" @@ -29,6 +30,7 @@ import ( "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/idp/dex" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integration_reference" @@ -58,7 +60,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) { } t.Cleanup(cleanup) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false) err = s.SaveAccount(context.Background(), account) if err != nil { @@ -105,7 +107,7 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { } t.Cleanup(cleanup) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false) account.Users[mockTargetUserId] = &types.User{ Id: mockTargetUserId, IsServiceUser: false, @@ -133,7 +135,7 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) { } t.Cleanup(cleanup) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false) account.Users[mockTargetUserId] = &types.User{ Id: mockTargetUserId, IsServiceUser: true, @@ -165,7 +167,7 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { } t.Cleanup(cleanup) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false) err = store.SaveAccount(context.Background(), account) if err != nil { @@ -190,7 +192,7 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) { } t.Cleanup(cleanup) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false) err = store.SaveAccount(context.Background(), account) if err != nil { @@ -215,7 +217,7 @@ func TestUser_DeletePAT(t *testing.T) { } t.Cleanup(cleanup) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false) account.Users[mockUserID] = &types.User{ Id: mockUserID, PATs: map[string]*types.PersonalAccessToken{ @@ -258,7 +260,7 @@ func TestUser_GetPAT(t *testing.T) { } t.Cleanup(cleanup) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false) account.Users[mockUserID] = &types.User{ Id: mockUserID, AccountID: mockAccountID, @@ -298,7 +300,7 @@ func TestUser_GetAllPATs(t *testing.T) { } t.Cleanup(cleanup) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false) account.Users[mockUserID] = &types.User{ Id: mockUserID, AccountID: mockAccountID, @@ -362,6 +364,8 @@ func TestUser_Copy(t *testing.T) { ID: 0, IntegrationType: "test", }, + Email: "whatever@gmail.com", + Name: "John Doe", } err := validateStruct(user) @@ -408,7 +412,7 @@ func TestUser_CreateServiceUser(t *testing.T) { } t.Cleanup(cleanup) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false) err = store.SaveAccount(context.Background(), account) if err != nil { @@ -455,7 +459,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) { } t.Cleanup(cleanup) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false) err = store.SaveAccount(context.Background(), account) if err != nil { @@ -503,7 +507,7 @@ func TestUser_CreateUser_RegularUser(t *testing.T) { } t.Cleanup(cleanup) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false) err = store.SaveAccount(context.Background(), account) if err != nil { @@ -534,7 +538,7 @@ func TestUser_InviteNewUser(t *testing.T) { } t.Cleanup(cleanup) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false) err = store.SaveAccount(context.Background(), account) if err != nil { @@ -641,7 +645,7 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { } t.Cleanup(cleanup) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false) account.Users[mockServiceUserID] = tt.serviceUser err = store.SaveAccount(context.Background(), account) @@ -680,7 +684,7 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) { } t.Cleanup(cleanup) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false) err = store.SaveAccount(context.Background(), account) if err != nil { @@ -707,7 +711,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { } t.Cleanup(cleanup) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false) targetId := "user2" account.Users[targetId] = &types.User{ @@ -801,7 +805,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { } t.Cleanup(cleanup) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false) targetId := "user2" account.Users[targetId] = &types.User{ @@ -969,7 +973,7 @@ func TestDefaultAccountManager_GetUser(t *testing.T) { } t.Cleanup(cleanup) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false) err = store.SaveAccount(context.Background(), account) if err != nil { @@ -1005,9 +1009,9 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) { } t.Cleanup(cleanup) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false) - account.Users["normal_user1"] = types.NewRegularUser("normal_user1") - account.Users["normal_user2"] = types.NewRegularUser("normal_user2") + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false) + account.Users["normal_user1"] = types.NewRegularUser("normal_user1", "", "") + account.Users["normal_user2"] = types.NewRegularUser("normal_user2", "", "") err = store.SaveAccount(context.Background(), account) if err != nil { @@ -1047,7 +1051,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { } t.Cleanup(cleanup) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false) externalUser := &types.User{ Id: "externalUser", Role: types.UserRoleUser, @@ -1104,7 +1108,7 @@ func TestUser_IsAdmin(t *testing.T) { user := types.NewAdminUser(mockUserID) assert.True(t, user.HasAdminPower()) - user = types.NewRegularUser(mockUserID) + user = types.NewRegularUser(mockUserID, "", "") assert.False(t, user.HasAdminPower()) } @@ -1115,7 +1119,7 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) { } t.Cleanup(cleanup) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false) account.Users[mockServiceUserID] = &types.User{ Id: mockServiceUserID, Role: "user", @@ -1149,7 +1153,7 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) { } t.Cleanup(cleanup) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false) account.Users[mockServiceUserID] = &types.User{ Id: mockServiceUserID, Role: "user", @@ -1320,13 +1324,13 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // create an account and an admin user - account, err := manager.GetOrCreateAccountByUser(context.Background(), ownerUserID, "netbird.io") + account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: ownerUserID, Domain: "netbird.io"}) if err != nil { t.Fatal(err) } // create other users - account.Users[regularUserID] = types.NewRegularUser(regularUserID) + account.Users[regularUserID] = types.NewRegularUser(regularUserID, "", "") account.Users[adminUserID] = types.NewAdminUser(adminUserID) account.Users[serviceUserID] = &types.User{IsServiceUser: true, Id: serviceUserID, Role: types.UserRoleAdmin, ServiceUserName: "service"} err = manager.Store.SaveAccount(context.Background(), account) @@ -1379,11 +1383,11 @@ func TestUserAccountPeersUpdate(t *testing.T) { updateManager.CloseChannel(context.Background(), peer1.ID) }) - // Creating a new regular user should not update account peers and not send peer update + // Creating a new regular user should send peer update (as users are not filtered yet) t.Run("creating new regular user with no groups", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -1402,11 +1406,11 @@ func TestUserAccountPeersUpdate(t *testing.T) { } }) - // updating user with no linked peers should not update account peers and not send peer update + // updating user with no linked peers should update account peers and send peer update (as users are not filtered yet) t.Run("updating user with no linked peers", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -1516,7 +1520,7 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) { } t.Cleanup(cleanup) - account1 := newAccountWithId(context.Background(), "account1", "ownerAccount1", "", false) + account1 := newAccountWithId(context.Background(), "account1", "ownerAccount1", "", "", "", false) targetId := "user2" account1.Users[targetId] = &types.User{ Id: targetId, @@ -1525,7 +1529,7 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) { } require.NoError(t, s.SaveAccount(context.Background(), account1)) - account2 := newAccountWithId(context.Background(), "account2", "ownerAccount2", "", false) + account2 := newAccountWithId(context.Background(), "account2", "ownerAccount2", "", "", "", false) require.NoError(t, s.SaveAccount(context.Background(), account2)) permissionsManager := permissions.NewManager(s) @@ -1552,7 +1556,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { } t.Cleanup(cleanup) - account1 := newAccountWithId(context.Background(), "account1", "account1Owner", "", false) + account1 := newAccountWithId(context.Background(), "account1", "account1Owner", "", "", "", false) account1.Settings.RegularUsersViewBlocked = false account1.Users["blocked-user"] = &types.User{ Id: "blocked-user", @@ -1574,7 +1578,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { } require.NoError(t, store.SaveAccount(context.Background(), account1)) - account2 := newAccountWithId(context.Background(), "account2", "account2Owner", "", false) + account2 := newAccountWithId(context.Background(), "account2", "account2Owner", "", "", "", false) account2.Users["settings-blocked-user"] = &types.User{ Id: "settings-blocked-user", Role: types.UserRoleUser, @@ -1771,7 +1775,7 @@ func TestApproveUser(t *testing.T) { } // Create account with admin and pending approval user - account := newAccountWithId(context.Background(), "account-1", "admin-user", "example.com", false) + account := newAccountWithId(context.Background(), "account-1", "admin-user", "example.com", "", "", false) err = manager.Store.SaveAccount(context.Background(), account) require.NoError(t, err) @@ -1782,7 +1786,7 @@ func TestApproveUser(t *testing.T) { require.NoError(t, err) // Create user pending approval - pendingUser := types.NewRegularUser("pending-user") + pendingUser := types.NewRegularUser("pending-user", "", "") pendingUser.AccountID = account.Id pendingUser.Blocked = true pendingUser.PendingApproval = true @@ -1807,12 +1811,12 @@ func TestApproveUser(t *testing.T) { assert.Contains(t, err.Error(), "not pending approval") // Test approval by non-admin should fail - regularUser := types.NewRegularUser("regular-user") + regularUser := types.NewRegularUser("regular-user", "", "") regularUser.AccountID = account.Id err = manager.Store.SaveUser(context.Background(), regularUser) require.NoError(t, err) - pendingUser2 := types.NewRegularUser("pending-user-2") + pendingUser2 := types.NewRegularUser("pending-user-2", "", "") pendingUser2.AccountID = account.Id pendingUser2.Blocked = true pendingUser2.PendingApproval = true @@ -1830,7 +1834,7 @@ func TestRejectUser(t *testing.T) { } // Create account with admin and pending approval user - account := newAccountWithId(context.Background(), "account-1", "admin-user", "example.com", false) + account := newAccountWithId(context.Background(), "account-1", "admin-user", "example.com", "", "", false) err = manager.Store.SaveAccount(context.Background(), account) require.NoError(t, err) @@ -1841,7 +1845,7 @@ func TestRejectUser(t *testing.T) { require.NoError(t, err) // Create user pending approval - pendingUser := types.NewRegularUser("pending-user") + pendingUser := types.NewRegularUser("pending-user", "", "") pendingUser.AccountID = account.Id pendingUser.Blocked = true pendingUser.PendingApproval = true @@ -1857,7 +1861,7 @@ func TestRejectUser(t *testing.T) { require.Error(t, err) // Test rejection of non-pending user should fail - regularUser := types.NewRegularUser("regular-user") + regularUser := types.NewRegularUser("regular-user", "", "") regularUser.AccountID = account.Id err = manager.Store.SaveUser(context.Background(), regularUser) require.NoError(t, err) @@ -1867,7 +1871,7 @@ func TestRejectUser(t *testing.T) { assert.Contains(t, err.Error(), "not pending approval") // Test rejection by non-admin should fail - pendingUser2 := types.NewRegularUser("pending-user-2") + pendingUser2 := types.NewRegularUser("pending-user-2", "", "") pendingUser2.AccountID = account.Id pendingUser2.Blocked = true pendingUser2.PendingApproval = true @@ -1877,3 +1881,149 @@ func TestRejectUser(t *testing.T) { err = manager.RejectUser(context.Background(), account.Id, regularUser.Id, pendingUser2.Id) require.Error(t, err) } + +func TestUser_Operations_WithEmbeddedIDP(t *testing.T) { + ctx := context.Background() + + // Create temporary directory for Dex + tmpDir := t.TempDir() + dexDataDir := tmpDir + "/dex" + require.NoError(t, os.MkdirAll(dexDataDir, 0700)) + + // Create embedded IDP config + embeddedIdPConfig := &idp.EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + Storage: idp.EmbeddedStorageConfig{ + Type: "sqlite3", + Config: idp.EmbeddedStorageTypeConfig{ + File: dexDataDir + "/dex.db", + }, + }, + } + + // Create embedded IDP manager + embeddedIdp, err := idp.NewEmbeddedIdPManager(ctx, embeddedIdPConfig, nil) + require.NoError(t, err) + defer func() { _ = embeddedIdp.Stop(ctx) }() + + // Create test store + testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", tmpDir) + require.NoError(t, err) + defer cleanup() + + // Create account with owner user + account := newAccountWithId(ctx, mockAccountID, mockUserID, "", "owner@test.com", "Owner User", false) + require.NoError(t, testStore.SaveAccount(ctx, account)) + + // Create mock network map controller + ctrl := gomock.NewController(t) + networkMapControllerMock := network_map.NewMockController(ctrl) + networkMapControllerMock.EXPECT(). + OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil). + AnyTimes() + + // Create account manager with embedded IDP + permissionsManager := permissions.NewManager(testStore) + am := DefaultAccountManager{ + Store: testStore, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsManager, + idpManager: embeddedIdp, + cacheLoading: map[string]chan struct{}{}, + networkMapController: networkMapControllerMock, + } + + // Initialize cache manager + cacheStore, err := nbcache.NewStore(ctx, nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn) + require.NoError(t, err) + am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, cacheStore) + am.externalCacheManager = nbcache.NewUserDataCache(cacheStore) + + t.Run("create regular user returns password", func(t *testing.T) { + userInfo, err := am.CreateUser(ctx, mockAccountID, mockUserID, &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + IsServiceUser: false, + }) + require.NoError(t, err) + require.NotNil(t, userInfo) + + // Verify user data + assert.Equal(t, "newuser@test.com", userInfo.Email) + assert.Equal(t, "New User", userInfo.Name) + assert.Equal(t, "user", userInfo.Role) + assert.NotEmpty(t, userInfo.ID) + + // IMPORTANT: Password should be returned for embedded IDP + assert.NotEmpty(t, userInfo.Password, "Password should be returned for embedded IDP user") + t.Logf("Created user: ID=%s, Email=%s, Password=%s", userInfo.ID, userInfo.Email, userInfo.Password) + + // Verify user ID is in Dex encoded format + rawUserID, connectorID, err := dex.DecodeDexUserID(userInfo.ID) + require.NoError(t, err) + assert.NotEmpty(t, rawUserID) + assert.Equal(t, "local", connectorID) + t.Logf("Decoded user ID: rawUserID=%s, connectorID=%s", rawUserID, connectorID) + + // Verify user exists in database with correct data + dbUser, err := testStore.GetUserByUserID(ctx, store.LockingStrengthNone, userInfo.ID) + require.NoError(t, err) + assert.Equal(t, "newuser@test.com", dbUser.Email) + assert.Equal(t, "New User", dbUser.Name) + + // Store user ID for delete test + createdUserID := userInfo.ID + + t.Run("delete user works", func(t *testing.T) { + err := am.DeleteUser(ctx, mockAccountID, mockUserID, createdUserID) + require.NoError(t, err) + + // Verify user is deleted from database + _, err = testStore.GetUserByUserID(ctx, store.LockingStrengthNone, createdUserID) + assert.Error(t, err, "User should be deleted from database") + }) + }) + + t.Run("create service user does not return password", func(t *testing.T) { + userInfo, err := am.CreateUser(ctx, mockAccountID, mockUserID, &types.UserInfo{ + Name: "Service User", + Role: "user", + AutoGroups: []string{}, + IsServiceUser: true, + }) + require.NoError(t, err) + require.NotNil(t, userInfo) + + assert.True(t, userInfo.IsServiceUser) + assert.Equal(t, "Service User", userInfo.Name) + // Service users don't have passwords + assert.Empty(t, userInfo.Password, "Service users should not have passwords") + }) + + t.Run("duplicate email fails", func(t *testing.T) { + // Create first user + _, err := am.CreateUser(ctx, mockAccountID, mockUserID, &types.UserInfo{ + Email: "duplicate@test.com", + Name: "First User", + Role: "user", + AutoGroups: []string{}, + IsServiceUser: false, + }) + require.NoError(t, err) + + // Try to create second user with same email + _, err = am.CreateUser(ctx, mockAccountID, mockUserID, &types.UserInfo{ + Email: "duplicate@test.com", + Name: "Second User", + Role: "user", + AutoGroups: []string{}, + IsServiceUser: false, + }) + assert.Error(t, err, "Creating user with duplicate email should fail") + t.Logf("Duplicate email error: %v", err) + }) +} diff --git a/release_files/freebsd-port-diff.sh b/release_files/freebsd-port-diff.sh new file mode 100755 index 000000000..b030b9164 --- /dev/null +++ b/release_files/freebsd-port-diff.sh @@ -0,0 +1,216 @@ +#!/bin/bash +# +# FreeBSD Port Diff Generator for NetBird +# +# This script generates the diff file required for submitting a FreeBSD port update. +# It works on macOS, Linux, and FreeBSD by fetching files from FreeBSD cgit and +# computing checksums from the Go module proxy. +# +# Usage: ./freebsd-port-diff.sh [new_version] +# Example: ./freebsd-port-diff.sh 0.60.7 +# +# If no version is provided, it fetches the latest from GitHub. + +set -e + +GITHUB_REPO="netbirdio/netbird" +PORTS_CGIT_BASE="https://cgit.freebsd.org/ports/plain/security/netbird" +GO_PROXY="https://proxy.golang.org/github.com/netbirdio/netbird/@v" +OUTPUT_DIR="${OUTPUT_DIR:-.}" +AWK_FIRST_FIELD='{print $1}' + +fetch_all_tags() { + curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \ + grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \ + sed 's/.*\/v//' | \ + sort -u -V + return 0 +} + +fetch_current_ports_version() { + echo "Fetching current version from FreeBSD ports..." >&2 + curl -sL "${PORTS_CGIT_BASE}/Makefile" 2>/dev/null | \ + grep -E "^DISTVERSION=" | \ + sed 's/DISTVERSION=[[:space:]]*//' | \ + tr -d '\t ' + return 0 +} + +fetch_latest_github_release() { + echo "Fetching latest release from GitHub..." >&2 + fetch_all_tags | tail -1 + return 0 +} + +fetch_ports_file() { + local filename="$1" + curl -sL "${PORTS_CGIT_BASE}/${filename}" 2>/dev/null + return 0 +} + +compute_checksums() { + local version="$1" + local tmpdir + tmpdir=$(mktemp -d) + # shellcheck disable=SC2064 + trap "rm -rf '$tmpdir'" EXIT + + echo "Downloading files from Go module proxy for v${version}..." >&2 + + local mod_file="${tmpdir}/v${version}.mod" + local zip_file="${tmpdir}/v${version}.zip" + + curl -sL "${GO_PROXY}/v${version}.mod" -o "$mod_file" 2>/dev/null + curl -sL "${GO_PROXY}/v${version}.zip" -o "$zip_file" 2>/dev/null + + if [[ ! -s "$mod_file" ]] || [[ ! -s "$zip_file" ]]; then + echo "Error: Could not download files from Go module proxy" >&2 + return 1 + fi + + local mod_sha256 mod_size zip_sha256 zip_size + + if command -v sha256sum &>/dev/null; then + mod_sha256=$(sha256sum "$mod_file" | awk "$AWK_FIRST_FIELD") + zip_sha256=$(sha256sum "$zip_file" | awk "$AWK_FIRST_FIELD") + elif command -v shasum &>/dev/null; then + mod_sha256=$(shasum -a 256 "$mod_file" | awk "$AWK_FIRST_FIELD") + zip_sha256=$(shasum -a 256 "$zip_file" | awk "$AWK_FIRST_FIELD") + else + echo "Error: No sha256 command found" >&2 + return 1 + fi + + if [[ "$OSTYPE" == "darwin"* ]]; then + mod_size=$(stat -f%z "$mod_file") + zip_size=$(stat -f%z "$zip_file") + else + mod_size=$(stat -c%s "$mod_file") + zip_size=$(stat -c%s "$zip_file") + fi + + echo "TIMESTAMP = $(date +%s)" + echo "SHA256 (go/security_netbird/netbird-v${version}/v${version}.mod) = ${mod_sha256}" + echo "SIZE (go/security_netbird/netbird-v${version}/v${version}.mod) = ${mod_size}" + echo "SHA256 (go/security_netbird/netbird-v${version}/v${version}.zip) = ${zip_sha256}" + echo "SIZE (go/security_netbird/netbird-v${version}/v${version}.zip) = ${zip_size}" + return 0 +} + +generate_new_makefile() { + local new_version="$1" + local old_makefile="$2" + + # Check if old version had PORTREVISION + if echo "$old_makefile" | grep -q "^PORTREVISION="; then + # Remove PORTREVISION line and update DISTVERSION + echo "$old_makefile" | \ + sed "s/^DISTVERSION=.*/DISTVERSION= ${new_version}/" | \ + grep -v "^PORTREVISION=" + else + # Just update DISTVERSION + echo "$old_makefile" | \ + sed "s/^DISTVERSION=.*/DISTVERSION= ${new_version}/" + fi + return 0 +} + +# Parse arguments +NEW_VERSION="${1:-}" + +# Auto-detect versions if not provided +OLD_VERSION=$(fetch_current_ports_version) +if [[ -z "$OLD_VERSION" ]]; then + echo "Error: Could not fetch current version from FreeBSD ports" >&2 + exit 1 +fi +echo "Current FreeBSD ports version: ${OLD_VERSION}" >&2 + +if [[ -z "$NEW_VERSION" ]]; then + NEW_VERSION=$(fetch_latest_github_release) + if [[ -z "$NEW_VERSION" ]]; then + echo "Error: Could not fetch latest release from GitHub" >&2 + exit 1 + fi +fi +echo "Target version: ${NEW_VERSION}" >&2 + +if [[ "$OLD_VERSION" = "$NEW_VERSION" ]]; then + echo "Port is already at version ${NEW_VERSION}. Nothing to do." >&2 + exit 0 +fi + +echo "" >&2 + +# Fetch current files +echo "Fetching current Makefile from FreeBSD ports..." >&2 +OLD_MAKEFILE=$(fetch_ports_file "Makefile") +if [[ -z "$OLD_MAKEFILE" ]]; then + echo "Error: Could not fetch Makefile" >&2 + exit 1 +fi + +echo "Fetching current distinfo from FreeBSD ports..." >&2 +OLD_DISTINFO=$(fetch_ports_file "distinfo") +if [[ -z "$OLD_DISTINFO" ]]; then + echo "Error: Could not fetch distinfo" >&2 + exit 1 +fi + +# Generate new files +echo "Generating new Makefile..." >&2 +NEW_MAKEFILE=$(generate_new_makefile "$NEW_VERSION" "$OLD_MAKEFILE") + +echo "Computing checksums for new version..." >&2 +NEW_DISTINFO=$(compute_checksums "$NEW_VERSION") +if [[ -z "$NEW_DISTINFO" ]]; then + echo "Error: Could not compute checksums" >&2 + exit 1 +fi + +# Create temp files for diff +TMPDIR=$(mktemp -d) +# shellcheck disable=SC2064 +trap "rm -rf '$TMPDIR'" EXIT + +mkdir -p "${TMPDIR}/a/security/netbird" "${TMPDIR}/b/security/netbird" + +echo "$OLD_MAKEFILE" > "${TMPDIR}/a/security/netbird/Makefile" +echo "$OLD_DISTINFO" > "${TMPDIR}/a/security/netbird/distinfo" +echo "$NEW_MAKEFILE" > "${TMPDIR}/b/security/netbird/Makefile" +echo "$NEW_DISTINFO" > "${TMPDIR}/b/security/netbird/distinfo" + +# Generate diff +OUTPUT_FILE="${OUTPUT_DIR}/netbird-${NEW_VERSION}.diff" + +echo "" >&2 +echo "Generating diff..." >&2 + +# Generate diff and clean up temp paths to show standard a/b paths +(cd "${TMPDIR}" && diff -ruN "a/security/netbird" "b/security/netbird") > "$OUTPUT_FILE" || true + +if [[ ! -s "$OUTPUT_FILE" ]]; then + echo "Error: Generated diff is empty" >&2 + exit 1 +fi + +echo "" >&2 +echo "=========================================" +echo "Diff saved to: ${OUTPUT_FILE}" +echo "=========================================" +echo "" +cat "$OUTPUT_FILE" +echo "" +echo "=========================================" +echo "" +echo "Next steps:" +echo "1. Review the diff above" +echo "2. Submit to https://bugs.freebsd.org/bugzilla/" +echo "3. Use ./freebsd-port-issue-body.sh to generate the issue content" +echo "" +echo "For FreeBSD testing (optional but recommended):" +echo " cd /usr/ports/security/netbird" +echo " patch < ${OUTPUT_FILE}" +echo " make stage && make stage-qa && make package && make install" +echo " netbird status" +echo " make deinstall" diff --git a/release_files/freebsd-port-issue-body.sh b/release_files/freebsd-port-issue-body.sh new file mode 100755 index 000000000..b7ad0f5b1 --- /dev/null +++ b/release_files/freebsd-port-issue-body.sh @@ -0,0 +1,159 @@ +#!/bin/bash +# +# FreeBSD Port Issue Body Generator for NetBird +# +# This script generates the issue body content for submitting a FreeBSD port update +# to the FreeBSD Bugzilla at https://bugs.freebsd.org/bugzilla/ +# +# Usage: ./freebsd-port-issue-body.sh [old_version] [new_version] +# Example: ./freebsd-port-issue-body.sh 0.56.0 0.59.1 +# +# If no versions are provided, the script will: +# - Fetch OLD version from FreeBSD ports cgit (current version in ports tree) +# - Fetch NEW version from latest NetBird GitHub release tag + +set -e + +GITHUB_REPO="netbirdio/netbird" +PORTS_CGIT_URL="https://cgit.freebsd.org/ports/plain/security/netbird/Makefile" + +fetch_current_ports_version() { + echo "Fetching current version from FreeBSD ports..." >&2 + local makefile_content + makefile_content=$(curl -sL "$PORTS_CGIT_URL" 2>/dev/null) + if [[ -z "$makefile_content" ]]; then + echo "Error: Could not fetch Makefile from FreeBSD ports" >&2 + return 1 + fi + echo "$makefile_content" | grep -E "^DISTVERSION=" | sed 's/DISTVERSION=[[:space:]]*//' | tr -d '\t ' + return 0 +} + +fetch_all_tags() { + # Fetch tags from GitHub tags page (no rate limiting, no auth needed) + curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \ + grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \ + sed 's/.*\/v//' | \ + sort -u -V + return 0 +} + +fetch_latest_github_release() { + echo "Fetching latest release from GitHub..." >&2 + local latest + + # Fetch from GitHub tags page + latest=$(fetch_all_tags | tail -1) + + if [[ -z "$latest" ]]; then + # Fallback to GitHub API + latest=$(curl -sL "https://api.github.com/repos/${GITHUB_REPO}/releases/latest" 2>/dev/null | \ + grep '"tag_name"' | sed 's/.*"tag_name": *"v\([^"]*\)".*/\1/') + fi + + if [[ -z "$latest" ]]; then + echo "Error: Could not fetch latest release from GitHub" >&2 + return 1 + fi + echo "$latest" + return 0 +} + +OLD_VERSION="${1:-}" +NEW_VERSION="${2:-}" + +if [[ -z "$OLD_VERSION" ]]; then + OLD_VERSION=$(fetch_current_ports_version) + if [[ -z "$OLD_VERSION" ]]; then + echo "Error: Could not determine old version. Please provide it manually." >&2 + echo "Usage: $0 " >&2 + exit 1 + fi + echo "Detected OLD version from FreeBSD ports: $OLD_VERSION" >&2 +fi + +if [[ -z "$NEW_VERSION" ]]; then + NEW_VERSION=$(fetch_latest_github_release) + if [[ -z "$NEW_VERSION" ]]; then + echo "Error: Could not determine new version. Please provide it manually." >&2 + echo "Usage: $0 " >&2 + exit 1 + fi + echo "Detected NEW version from GitHub: $NEW_VERSION" >&2 +fi + +if [[ "$OLD_VERSION" = "$NEW_VERSION" ]]; then + echo "Warning: OLD and NEW versions are the same ($OLD_VERSION). Port may already be up to date." >&2 +fi + +echo "" >&2 + +OUTPUT_DIR="${OUTPUT_DIR:-.}" + +fetch_releases_between_versions() { + echo "Fetching release history from GitHub..." >&2 + + # Fetch all tags and filter to those between OLD and NEW versions + fetch_all_tags | \ + while read -r ver; do + if [[ "$(printf '%s\n' "$OLD_VERSION" "$ver" | sort -V | head -n1)" = "$OLD_VERSION" ]] && \ + [[ "$(printf '%s\n' "$ver" "$NEW_VERSION" | sort -V | head -n1)" = "$ver" ]] && \ + [[ "$ver" != "$OLD_VERSION" ]]; then + echo "$ver" + fi + done + return 0 +} + +generate_changelog_section() { + local releases + releases=$(fetch_releases_between_versions) + + echo "Changelogs:" + if [[ -n "$releases" ]]; then + echo "$releases" | while read -r ver; do + echo "https://github.com/${GITHUB_REPO}/releases/tag/v${ver}" + done + else + echo "https://github.com/${GITHUB_REPO}/releases/tag/v${NEW_VERSION}" + fi + return 0 +} + +OUTPUT_FILE="${OUTPUT_DIR}/netbird-${NEW_VERSION}-issue.txt" + +cat << EOF > "$OUTPUT_FILE" +BUGZILLA ISSUE DETAILS +====================== + +Severity: Affects Some People + +Summary: security/netbird: Update to ${NEW_VERSION} + +Description: +------------ +security/netbird: Update ${OLD_VERSION} => ${NEW_VERSION} + +$(generate_changelog_section) + +Commit log: +https://github.com/${GITHUB_REPO}/compare/v${OLD_VERSION}...v${NEW_VERSION} +EOF + +echo "=========================================" +echo "Issue body saved to: ${OUTPUT_FILE}" +echo "=========================================" +echo "" +cat "$OUTPUT_FILE" +echo "" +echo "=========================================" +echo "" +echo "Next steps:" +echo "1. Go to https://bugs.freebsd.org/bugzilla/ and login" +echo "2. Click 'Report an update or defect to a port'" +echo "3. Fill in:" +echo " - Severity: Affects Some People" +echo " - Summary: security/netbird: Update to ${NEW_VERSION}" +echo " - Description: Copy content from ${OUTPUT_FILE}" +echo "4. Attach diff file: netbird-${NEW_VERSION}.diff" +echo "5. Submit the bug report" diff --git a/release_files/ui-post-install.sh b/release_files/ui-post-install.sh index f6e8ddf92..ff6c4ee9b 100644 --- a/release_files/ui-post-install.sh +++ b/release_files/ui-post-install.sh @@ -1,10 +1,15 @@ #!/bin/sh +set -e +set -u + # Check if netbird-ui is running -if pgrep -x -f /usr/bin/netbird-ui >/dev/null 2>&1; +pid="$(pgrep -x -f /usr/bin/netbird-ui || true)" +if [ -n "${pid}" ] then - runner=$(ps --no-headers -o '%U' -p $(pgrep -x -f /usr/bin/netbird-ui) | sed 's/^[ \t]*//;s/[ \t]*$//') + uid="$(cat /proc/"${pid}"/loginuid)" + username="$(id -nu "${uid}")" # Only re-run if it was already running pkill -x -f /usr/bin/netbird-ui >/dev/null 2>&1 - su -l - "$runner" -c 'nohup /usr/bin/netbird-ui > /dev/null 2>&1 &' + su - "${username}" -c 'nohup /usr/bin/netbird-ui > /dev/null 2>&1 &' fi diff --git a/shared/auth/jwt/extractor.go b/shared/auth/jwt/extractor.go index a41d5f07a..5806d1f4d 100644 --- a/shared/auth/jwt/extractor.go +++ b/shared/auth/jwt/extractor.go @@ -78,16 +78,18 @@ func parseTime(timeString string) time.Time { return parsedTime } -func (c ClaimsExtractor) audienceClaim(claimName string) string { - url, err := url.JoinPath(c.authAudience, claimName) +func (c *ClaimsExtractor) audienceClaim(claimName string) string { + audienceURL, err := url.JoinPath(c.authAudience, claimName) if err != nil { return c.authAudience + claimName // as it was previously } - return url + return audienceURL } -// ToUserAuth extracts user authentication information from a JWT token +// ToUserAuth extracts user authentication information from a JWT token. +// The token should contain standard claims like email, name, preferred_username. +// When using Dex, make sure to set getUserInfo: true to have these claims populated. func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (auth.UserAuth, error) { claims := token.Claims.(jwt.MapClaims) userAuth := auth.UserAuth{} @@ -120,6 +122,21 @@ func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (auth.UserAuth, error) { } } + // Extract email from standard "email" claim + if email, ok := claims["email"].(string); ok { + userAuth.Email = email + } + + // Extract name from standard "name" claim + if name, ok := claims["name"].(string); ok { + userAuth.Name = name + } + + // Extract name from standard "preferred_username" claim + if preferredName, ok := claims["preferred_username"].(string); ok { + userAuth.PreferredName = preferredName + } + return userAuth, nil } diff --git a/shared/auth/jwt/extractor_test.go b/shared/auth/jwt/extractor_test.go new file mode 100644 index 000000000..45529770d --- /dev/null +++ b/shared/auth/jwt/extractor_test.go @@ -0,0 +1,322 @@ +package jwt + +import ( + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestClaimsExtractor_ToUserAuth_ExtractsEmailAndName(t *testing.T) { + tests := []struct { + name string + claims jwt.MapClaims + userIDClaim string + audience string + expectedUserID string + expectedEmail string + expectedName string + expectError bool + }{ + { + name: "extracts email and name from standard claims", + claims: jwt.MapClaims{ + "sub": "user-123", + "email": "test@example.com", + "name": "Test User", + }, + userIDClaim: "sub", + expectedUserID: "user-123", + expectedEmail: "test@example.com", + expectedName: "Test User", + }, + { + name: "extracts Dex encoded user ID", + claims: jwt.MapClaims{ + "sub": "CiQ3YWFkOGMwNS0zMjg3LTQ3M2YtYjQyYS0zNjU1MDRiZjI1ZTcSBWxvY2Fs", + "email": "dex-user@example.com", + "name": "Dex User", + }, + userIDClaim: "sub", + expectedUserID: "CiQ3YWFkOGMwNS0zMjg3LTQ3M2YtYjQyYS0zNjU1MDRiZjI1ZTcSBWxvY2Fs", + expectedEmail: "dex-user@example.com", + expectedName: "Dex User", + }, + { + name: "handles missing email claim", + claims: jwt.MapClaims{ + "sub": "user-456", + "name": "User Without Email", + }, + userIDClaim: "sub", + expectedUserID: "user-456", + expectedEmail: "", + expectedName: "User Without Email", + }, + { + name: "handles missing name claim", + claims: jwt.MapClaims{ + "sub": "user-789", + "email": "noname@example.com", + }, + userIDClaim: "sub", + expectedUserID: "user-789", + expectedEmail: "noname@example.com", + expectedName: "", + }, + { + name: "handles missing both email and name", + claims: jwt.MapClaims{ + "sub": "user-minimal", + }, + userIDClaim: "sub", + expectedUserID: "user-minimal", + expectedEmail: "", + expectedName: "", + }, + { + name: "extracts preferred_username", + claims: jwt.MapClaims{ + "sub": "user-pref", + "email": "pref@example.com", + "name": "Preferred User", + "preferred_username": "prefuser", + }, + userIDClaim: "sub", + expectedUserID: "user-pref", + expectedEmail: "pref@example.com", + expectedName: "Preferred User", + }, + { + name: "fails when user ID claim is empty", + claims: jwt.MapClaims{ + "email": "test@example.com", + "name": "Test User", + }, + userIDClaim: "sub", + expectError: true, + }, + { + name: "uses custom user ID claim", + claims: jwt.MapClaims{ + "user_id": "custom-user-id", + "email": "custom@example.com", + "name": "Custom User", + }, + userIDClaim: "user_id", + expectedUserID: "custom-user-id", + expectedEmail: "custom@example.com", + expectedName: "Custom User", + }, + { + name: "extracts account ID with audience prefix", + claims: jwt.MapClaims{ + "sub": "user-with-account", + "email": "account@example.com", + "name": "Account User", + "https://api.netbird.io/wt_account_id": "account-123", + "https://api.netbird.io/wt_account_domain": "example.com", + }, + userIDClaim: "sub", + audience: "https://api.netbird.io", + expectedUserID: "user-with-account", + expectedEmail: "account@example.com", + expectedName: "Account User", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create extractor with options + opts := []ClaimsExtractorOption{} + if tt.userIDClaim != "" { + opts = append(opts, WithUserIDClaim(tt.userIDClaim)) + } + if tt.audience != "" { + opts = append(opts, WithAudience(tt.audience)) + } + extractor := NewClaimsExtractor(opts...) + + // Create a mock token with the claims + token := &jwt.Token{ + Claims: tt.claims, + } + + // Extract user auth + userAuth, err := extractor.ToUserAuth(token) + + if tt.expectError { + assert.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.expectedUserID, userAuth.UserId) + assert.Equal(t, tt.expectedEmail, userAuth.Email) + assert.Equal(t, tt.expectedName, userAuth.Name) + }) + } +} + +func TestClaimsExtractor_ToUserAuth_PreferredUsername(t *testing.T) { + extractor := NewClaimsExtractor(WithUserIDClaim("sub")) + + claims := jwt.MapClaims{ + "sub": "user-123", + "email": "test@example.com", + "name": "Test User", + "preferred_username": "testuser", + } + + token := &jwt.Token{Claims: claims} + + userAuth, err := extractor.ToUserAuth(token) + require.NoError(t, err) + + assert.Equal(t, "user-123", userAuth.UserId) + assert.Equal(t, "test@example.com", userAuth.Email) + assert.Equal(t, "Test User", userAuth.Name) + assert.Equal(t, "testuser", userAuth.PreferredName) +} + +func TestClaimsExtractor_ToUserAuth_LastLogin(t *testing.T) { + extractor := NewClaimsExtractor( + WithUserIDClaim("sub"), + WithAudience("https://api.netbird.io"), + ) + + expectedTime := time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC) + + claims := jwt.MapClaims{ + "sub": "user-123", + "email": "test@example.com", + "https://api.netbird.io/nb_last_login": expectedTime.Format(time.RFC3339), + } + + token := &jwt.Token{Claims: claims} + + userAuth, err := extractor.ToUserAuth(token) + require.NoError(t, err) + + assert.Equal(t, expectedTime, userAuth.LastLogin) +} + +func TestClaimsExtractor_ToUserAuth_Invited(t *testing.T) { + extractor := NewClaimsExtractor( + WithUserIDClaim("sub"), + WithAudience("https://api.netbird.io"), + ) + + claims := jwt.MapClaims{ + "sub": "user-123", + "email": "invited@example.com", + "https://api.netbird.io/nb_invited": true, + } + + token := &jwt.Token{Claims: claims} + + userAuth, err := extractor.ToUserAuth(token) + require.NoError(t, err) + + assert.True(t, userAuth.Invited) +} + +func TestClaimsExtractor_ToGroups(t *testing.T) { + extractor := NewClaimsExtractor(WithUserIDClaim("sub")) + + tests := []struct { + name string + claims jwt.MapClaims + groupClaimName string + expectedGroups []string + }{ + { + name: "extracts groups from claim", + claims: jwt.MapClaims{ + "sub": "user-123", + "groups": []interface{}{"admin", "users", "developers"}, + }, + groupClaimName: "groups", + expectedGroups: []string{"admin", "users", "developers"}, + }, + { + name: "returns empty slice when claim missing", + claims: jwt.MapClaims{ + "sub": "user-123", + }, + groupClaimName: "groups", + expectedGroups: []string{}, + }, + { + name: "handles custom claim name", + claims: jwt.MapClaims{ + "sub": "user-123", + "user_roles": []interface{}{"role1", "role2"}, + }, + groupClaimName: "user_roles", + expectedGroups: []string{"role1", "role2"}, + }, + { + name: "filters non-string values", + claims: jwt.MapClaims{ + "sub": "user-123", + "groups": []interface{}{"admin", 123, "users", true}, + }, + groupClaimName: "groups", + expectedGroups: []string{"admin", "users"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token := &jwt.Token{Claims: tt.claims} + groups := extractor.ToGroups(token, tt.groupClaimName) + assert.Equal(t, tt.expectedGroups, groups) + }) + } +} + +func TestClaimsExtractor_DefaultUserIDClaim(t *testing.T) { + // When no user ID claim is specified, it should default to "sub" + extractor := NewClaimsExtractor() + + claims := jwt.MapClaims{ + "sub": "default-user-id", + "email": "default@example.com", + } + + token := &jwt.Token{Claims: claims} + + userAuth, err := extractor.ToUserAuth(token) + require.NoError(t, err) + + assert.Equal(t, "default-user-id", userAuth.UserId) +} + +func TestClaimsExtractor_DexUserIDFormat(t *testing.T) { + // Test that the extractor correctly handles Dex's encoded user ID format + // Dex encodes user IDs as base64(protobuf{user_id, connector_id}) + extractor := NewClaimsExtractor(WithUserIDClaim("sub")) + + // This is an actual Dex-encoded user ID + dexEncodedID := "CiQ3YWFkOGMwNS0zMjg3LTQ3M2YtYjQyYS0zNjU1MDRiZjI1ZTcSBWxvY2Fs" + + claims := jwt.MapClaims{ + "sub": dexEncodedID, + "email": "dex@example.com", + "name": "Dex User", + } + + token := &jwt.Token{Claims: claims} + + userAuth, err := extractor.ToUserAuth(token) + require.NoError(t, err) + + // The extractor should pass through the encoded ID as-is + // Decoding is done elsewhere (e.g., in the Dex provider) + assert.Equal(t, dexEncodedID, userAuth.UserId) + assert.Equal(t, "dex@example.com", userAuth.Email) + assert.Equal(t, "Dex User", userAuth.Name) +} diff --git a/shared/auth/jwt/validator.go b/shared/auth/jwt/validator.go index 239447b96..ede7acea5 100644 --- a/shared/auth/jwt/validator.go +++ b/shared/auth/jwt/validator.go @@ -60,6 +60,7 @@ type Validator struct { keysLocation string idpSignkeyRefreshEnabled bool keys *Jwks + lastForcedRefresh time.Time } var ( @@ -84,26 +85,17 @@ func NewValidator(issuer string, audienceList []string, keysLocation string, idp } } +// forcedRefreshCooldown is the minimum time between forced key refreshes +// to prevent abuse from invalid tokens with fake kid values +const forcedRefreshCooldown = 30 * time.Second + func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc { return func(token *jwt.Token) (interface{}, error) { // If keys are rotated, verify the keys prior to token validation if v.idpSignkeyRefreshEnabled { // If the keys are invalid, retrieve new ones - // @todo propose a separate go routine to regularly check these to prevent blocking when actually - // validating the token if !v.keys.stillValid() { - v.lock.Lock() - defer v.lock.Unlock() - - refreshedKeys, err := getPemKeys(v.keysLocation) - if err != nil { - log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err) - refreshedKeys = v.keys - } - - log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC()) - - v.keys = refreshedKeys + v.refreshKeys(ctx) } } @@ -112,6 +104,18 @@ func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc { return publicKey, nil } + // If key not found and refresh is enabled, try refreshing keys and retry once. + // This handles the case where keys were rotated but cache hasn't expired yet. + // Use a cooldown to prevent abuse from tokens with fake kid values. + if errors.Is(err, errKeyNotFound) && v.idpSignkeyRefreshEnabled { + if v.forceRefreshKeys(ctx) { + publicKey, err = getPublicKey(token, v.keys) + if err == nil { + return publicKey, nil + } + } + } + msg := fmt.Sprintf("getPublicKey error: %s", err) if errors.Is(err, errKeyNotFound) && !v.idpSignkeyRefreshEnabled { msg = fmt.Sprintf("getPublicKey error: %s. You can enable key refresh by setting HttpServerConfig.IdpSignKeyRefreshEnabled to true in your management.json file and restart the service", err) @@ -123,6 +127,46 @@ func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc { } } +func (v *Validator) refreshKeys(ctx context.Context) { + v.lock.Lock() + defer v.lock.Unlock() + + refreshedKeys, err := getPemKeys(v.keysLocation) + if err != nil { + log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err) + return + } + + log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC()) + v.keys = refreshedKeys +} + +// forceRefreshKeys refreshes keys if the cooldown period has passed. +// Returns true if keys were refreshed, false if cooldown prevented refresh. +// The cooldown check is done inside the lock to prevent race conditions. +func (v *Validator) forceRefreshKeys(ctx context.Context) bool { + v.lock.Lock() + defer v.lock.Unlock() + + // Check cooldown inside lock to prevent multiple goroutines from refreshing + if time.Since(v.lastForcedRefresh) <= forcedRefreshCooldown { + return false + } + + log.WithContext(ctx).Debugf("key not found in cache, forcing JWKS refresh") + + refreshedKeys, err := getPemKeys(v.keysLocation) + if err != nil { + log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err) + return false + } + + log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC()) + v.keys = refreshedKeys + v.lastForcedRefresh = time.Now() + return true +} + // ValidateAndParse validates the token and returns the parsed token func (v *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { // If the token is empty... @@ -165,12 +209,12 @@ func (jwks *Jwks) stillValid() bool { func getPemKeys(keysLocation string) (*Jwks, error) { jwks := &Jwks{} - url, err := url.ParseRequestURI(keysLocation) + requestURI, err := url.ParseRequestURI(keysLocation) if err != nil { return jwks, err } - resp, err := http.Get(url.String()) + resp, err := http.Get(requestURI.String()) if err != nil { return jwks, err } diff --git a/shared/auth/user.go b/shared/auth/user.go index c1bae808e..00a3d2b64 100644 --- a/shared/auth/user.go +++ b/shared/auth/user.go @@ -18,6 +18,15 @@ type UserAuth struct { // The user id UserId string + // The user's email address + // (optional, may be empty if not in token, make sure to set getUserInfo: true in Dex to have this field) + Email string + // The user's name + // (optional, may be empty if not in token, make sure to set getUserInfo: true in Dex to have this field) + Name string + // The user's preferred name + // (optional, may be empty if not in token, make sure to set getUserInfo: true in Dex to have this field) + PreferredName string // Last login time for this user LastLogin time.Time // The Groups the user belongs to on this account diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index 9fbe70948..64f6831f2 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -129,7 +129,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { if err != nil { t.Fatal(err) } - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, mgmt.MockIntegratedValidator{}, networkMapController) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, mgmt.MockIntegratedValidator{}, networkMapController, nil) if err != nil { t.Fatal(err) } diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index 520a83e36..89860ac9b 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -111,6 +111,8 @@ func (c *GrpcClient) ready() bool { // Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages // Blocking request. The result will be sent via msgHandler callback function func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error { + backOff := defaultBackoff(ctx) + operation := func() error { log.Debugf("management connection state %v", c.conn.GetState()) connState := c.conn.GetState() @@ -128,10 +130,10 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler return err } - return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler) + return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler, backOff) } - err := backoff.Retry(operation, defaultBackoff(ctx)) + err := backoff.Retry(operation, backOff) if err != nil { log.Warnf("exiting the Management service connection retry loop due to the unrecoverable error: %s", err) } @@ -140,7 +142,7 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler } func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info, - msgHandler func(msg *proto.SyncResponse) error) error { + msgHandler func(msg *proto.SyncResponse) error, backOff backoff.BackOff) error { ctx, cancelStream := context.WithCancel(ctx) defer cancelStream() @@ -158,6 +160,9 @@ func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, // blocking until error err = c.receiveEvents(stream, serverPubKey, msgHandler) + // we need this reset because after a successful connection and a consequent error, backoff lib doesn't + // reset times and next try will start with a long delay + backOff.Reset() if err != nil { c.notifyDisconnected(err) s, _ := gstatus.FromError(err) diff --git a/shared/management/client/rest/client.go b/shared/management/client/rest/client.go index 2a5de5bbc..4d1de2631 100644 --- a/shared/management/client/rest/client.go +++ b/shared/management/client/rest/client.go @@ -16,6 +16,7 @@ type Client struct { managementURL string authHeader string httpClient HttpClient + userAgent string // Accounts NetBird account APIs // see more: https://docs.netbird.io/api/resources/accounts @@ -128,6 +129,9 @@ func (c *Client) NewRequest(ctx context.Context, method, path string, body io.Re if body != nil { req.Header.Add("Content-Type", "application/json") } + if c.userAgent != "" { + req.Header.Set("User-Agent", c.userAgent) + } if len(query) != 0 { q := req.URL.Query() diff --git a/shared/management/client/rest/client_test.go b/shared/management/client/rest/client_test.go index 54a0290d0..17df8dd8b 100644 --- a/shared/management/client/rest/client_test.go +++ b/shared/management/client/rest/client_test.go @@ -4,10 +4,14 @@ package rest_test import ( + "context" "net/http" "net/http/httptest" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" "github.com/netbirdio/netbird/shared/management/client/rest" ) @@ -32,3 +36,50 @@ func withBlackBoxServer(t *testing.T, callback func(*rest.Client)) { c := rest.New(server.URL, "nbp_apTmlmUXHSC4PKmHwtIZNaGr8eqcVI2gMURp") callback(c) } + +func TestClient_UserAgent_Set(t *testing.T) { + expectedUserAgent := "TestApp/1.2.3" + mux := &http.ServeMux{} + server := httptest.NewServer(mux) + defer server.Close() + + mux.HandleFunc("/api/accounts", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, expectedUserAgent, r.Header.Get("User-Agent")) + w.WriteHeader(200) + _, err := w.Write([]byte("[]")) + require.NoError(t, err) + }) + + c := rest.NewWithOptions( + rest.WithManagementURL(server.URL), + rest.WithPAT("test-token"), + rest.WithUserAgent(expectedUserAgent), + ) + + _, err := c.Accounts.List(context.Background()) + require.NoError(t, err) +} + +func TestClient_UserAgent_NotSet(t *testing.T) { + mux := &http.ServeMux{} + server := httptest.NewServer(mux) + defer server.Close() + + mux.HandleFunc("/api/accounts", func(w http.ResponseWriter, r *http.Request) { + // When no custom user agent is set, Go's default HTTP client will set one + // We just verify that the header exists (it will be Go's default) + userAgent := r.Header.Get("User-Agent") + assert.NotEmpty(t, userAgent) + w.WriteHeader(200) + _, err := w.Write([]byte("[]")) + require.NoError(t, err) + }) + + c := rest.NewWithOptions( + rest.WithManagementURL(server.URL), + rest.WithPAT("test-token"), + ) + + _, err := c.Accounts.List(context.Background()) + require.NoError(t, err) +} diff --git a/shared/management/client/rest/options.go b/shared/management/client/rest/options.go index 21f2394e9..17c7e15cd 100644 --- a/shared/management/client/rest/options.go +++ b/shared/management/client/rest/options.go @@ -42,3 +42,10 @@ func WithAuthHeader(value string) option { c.authHeader = value } } + +// WithUserAgent sets a custom User-Agent header for HTTP requests +func WithUserAgent(userAgent string) option { + return func(c *Client) { + c.userAgent = userAgent + } +} diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 2d063a7b5..64086e7ec 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -32,6 +32,10 @@ tags: - name: Ingress Ports description: Interact with and view information about the ingress peers and ports. x-cloud-only: true + - name: Identity Providers + description: Interact with and view information about identity providers. + - name: Instance + description: Instance setup and status endpoints for initial configuration. components: schemas: Account: @@ -145,6 +149,15 @@ components: description: Enables or disables experimental lazy connection type: boolean example: true + auto_update_version: + description: Set Clients auto-update version. "latest", "disabled", or a specific version (e.g "0.50.1") + type: string + example: "0.51.2" + embedded_idp_enabled: + description: Indicates whether the embedded identity provider (Dex) is enabled for this account. This is a read-only field. + type: boolean + readOnly: true + example: false required: - peer_login_expiration_enabled - peer_login_expiration @@ -202,6 +215,10 @@ components: description: User's email address type: string example: demo@netbird.io + password: + description: User's password. Only present when user is created (create user endpoint is called) and only when IdP supports user creation with password. + type: string + example: super_secure_password name: description: User's name from idp provider type: string @@ -248,6 +265,10 @@ components: description: How user was issued by API or Integration type: string example: api + idp_id: + description: Identity provider ID (connector ID) that the user authenticated with. Only populated for users with Dex-encoded user IDs. + type: string + example: okta-abc123 permissions: $ref: '#/components/schemas/UserPermissions' required: @@ -484,6 +505,8 @@ components: description: Indicates whether the peer is ephemeral or not type: boolean example: false + local_flags: + $ref: '#/components/schemas/PeerLocalFlags' required: - city_name - connected @@ -510,6 +533,49 @@ components: - serial_number - extra_dns_labels - ephemeral + PeerLocalFlags: + type: object + properties: + rosenpass_enabled: + description: Indicates whether Rosenpass is enabled on this peer + type: boolean + example: true + rosenpass_permissive: + description: Indicates whether Rosenpass is in permissive mode or not + type: boolean + example: false + server_ssh_allowed: + description: Indicates whether SSH access this peer is allowed or not + type: boolean + example: true + disable_client_routes: + description: Indicates whether client routes are disabled on this peer or not + type: boolean + example: false + disable_server_routes: + description: Indicates whether server routes are disabled on this peer or not + type: boolean + example: false + disable_dns: + description: Indicates whether DNS management is disabled on this peer or not + type: boolean + example: false + disable_firewall: + description: Indicates whether firewall management is disabled on this peer or not + type: boolean + example: false + block_lan_access: + description: Indicates whether LAN access is blocked on this peer when used as a routing peer + type: boolean + example: false + block_inbound: + description: Indicates whether inbound traffic is blocked on this peer + type: boolean + example: false + lazy_connection_enabled: + description: Indicates whether lazy connection is enabled on this peer + type: boolean + example: false PeerTemporaryAccessRequest: type: object properties: @@ -932,7 +998,7 @@ components: protocol: description: Policy rule type of the traffic type: string - enum: ["all", "tcp", "udp", "icmp"] + enum: ["all", "tcp", "udp", "icmp", "netbird-ssh"] example: "tcp" ports: description: Policy rule affected ports @@ -945,6 +1011,14 @@ components: type: array items: $ref: '#/components/schemas/RulePortRange' + authorized_groups: + description: Map of user group ids to a list of local users + type: object + additionalProperties: + type: array + items: + type: string + example: "group1" required: - name - enabled @@ -2193,6 +2267,118 @@ components: - page_size - total_records - total_pages + IdentityProviderType: + type: string + description: Type of identity provider + enum: + - oidc + - zitadel + - entra + - google + - okta + - pocketid + - microsoft + example: oidc + IdentityProvider: + type: object + properties: + id: + description: Identity provider ID + type: string + example: ch8i4ug6lnn4g9hqv7l0 + type: + $ref: '#/components/schemas/IdentityProviderType' + name: + description: Human-readable name for the identity provider + type: string + example: My OIDC Provider + issuer: + description: OIDC issuer URL + type: string + example: https://accounts.google.com + client_id: + description: OAuth2 client ID + type: string + example: 123456789.apps.googleusercontent.com + required: + - type + - name + - issuer + - client_id + IdentityProviderRequest: + type: object + properties: + type: + $ref: '#/components/schemas/IdentityProviderType' + name: + description: Human-readable name for the identity provider + type: string + example: My OIDC Provider + issuer: + description: OIDC issuer URL + type: string + example: https://accounts.google.com + client_id: + description: OAuth2 client ID + type: string + example: 123456789.apps.googleusercontent.com + client_secret: + description: OAuth2 client secret + type: string + example: secret123 + required: + - type + - name + - issuer + - client_id + - client_secret + InstanceStatus: + type: object + description: Instance status information + properties: + setup_required: + description: Indicates whether the instance requires initial setup + type: boolean + example: true + required: + - setup_required + SetupRequest: + type: object + description: Request to set up the initial admin user + properties: + email: + description: Email address for the admin user + type: string + example: admin@example.com + password: + description: Password for the admin user (minimum 8 characters) + type: string + format: password + minLength: 8 + example: securepassword123 + name: + description: Display name for the admin user (defaults to email if not provided) + type: string + example: Admin User + required: + - email + - password + - name + SetupResponse: + type: object + description: Response after successful instance setup + properties: + user_id: + description: The ID of the created user + type: string + example: abc123def456 + email: + description: Email address of the created user + type: string + example: admin@example.com + required: + - user_id + - email responses: not_found: description: Resource not found @@ -2230,6 +2416,48 @@ security: - BearerAuth: [ ] - TokenAuth: [ ] paths: + /api/instance: + get: + summary: Get Instance Status + description: Returns the instance status including whether initial setup is required. This endpoint does not require authentication. + tags: [ Instance ] + security: [ ] + responses: + '200': + description: Instance status information + content: + application/json: + schema: + $ref: '#/components/schemas/InstanceStatus' + '500': + "$ref": "#/components/responses/internal_error" + /api/setup: + post: + summary: Setup Instance + description: Creates the initial admin user for the instance. This endpoint does not require authentication but only works when setup is required (no accounts exist and embedded IDP is enabled). + tags: [ Instance ] + security: [ ] + requestBody: + description: Initial admin user details + required: true + content: + 'application/json': + schema: + $ref: '#/components/schemas/SetupRequest' + responses: + '200': + description: Setup completed successfully + content: + application/json: + schema: + $ref: '#/components/schemas/SetupResponse' + '400': + "$ref": "#/components/responses/bad_request" + '412': + description: Setup already completed + content: { } + '500': + "$ref": "#/components/responses/internal_error" /api/accounts: get: summary: List all Accounts @@ -4820,3 +5048,147 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + /api/identity-providers: + get: + summary: List all Identity Providers + description: Returns a list of all identity providers configured for the account + tags: [ Identity Providers ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: A JSON array of identity providers + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/IdentityProvider' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create an Identity Provider + description: Creates a new identity provider configuration + tags: [ Identity Providers ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + requestBody: + description: Identity provider configuration + content: + 'application/json': + schema: + $ref: '#/components/schemas/IdentityProviderRequest' + responses: + '200': + description: An Identity Provider object + content: + application/json: + schema: + $ref: '#/components/schemas/IdentityProvider' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/identity-providers/{idpId}: + get: + summary: Retrieve an Identity Provider + description: Get information about a specific identity provider + tags: [ Identity Providers ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: idpId + required: true + schema: + type: string + description: The unique identifier of an identity provider + responses: + '200': + description: An Identity Provider object + content: + application/json: + schema: + $ref: '#/components/schemas/IdentityProvider' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + put: + summary: Update an Identity Provider + description: Update an existing identity provider configuration + tags: [ Identity Providers ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: idpId + required: true + schema: + type: string + description: The unique identifier of an identity provider + requestBody: + description: Identity provider update + content: + 'application/json': + schema: + $ref: '#/components/schemas/IdentityProviderRequest' + responses: + '200': + description: An Identity Provider object + content: + application/json: + schema: + $ref: '#/components/schemas/IdentityProvider' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + delete: + summary: Delete an Identity Provider + description: Delete an identity provider configuration + tags: [ Identity Providers ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: idpId + required: true + schema: + type: string + description: The unique identifier of an identity provider + responses: + '200': + description: Delete status code + content: { } + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index d3e425548..ab5a65cb0 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -83,6 +83,17 @@ const ( GroupMinimumIssuedJwt GroupMinimumIssued = "jwt" ) +// Defines values for IdentityProviderType. +const ( + IdentityProviderTypeEntra IdentityProviderType = "entra" + IdentityProviderTypeGoogle IdentityProviderType = "google" + IdentityProviderTypeMicrosoft IdentityProviderType = "microsoft" + IdentityProviderTypeOidc IdentityProviderType = "oidc" + IdentityProviderTypeOkta IdentityProviderType = "okta" + IdentityProviderTypePocketid IdentityProviderType = "pocketid" + IdentityProviderTypeZitadel IdentityProviderType = "zitadel" +) + // Defines values for IngressPortAllocationPortMappingProtocol. const ( IngressPortAllocationPortMappingProtocolTcp IngressPortAllocationPortMappingProtocol = "tcp" @@ -130,10 +141,11 @@ const ( // Defines values for PolicyRuleProtocol. const ( - PolicyRuleProtocolAll PolicyRuleProtocol = "all" - PolicyRuleProtocolIcmp PolicyRuleProtocol = "icmp" - PolicyRuleProtocolTcp PolicyRuleProtocol = "tcp" - PolicyRuleProtocolUdp PolicyRuleProtocol = "udp" + PolicyRuleProtocolAll PolicyRuleProtocol = "all" + PolicyRuleProtocolIcmp PolicyRuleProtocol = "icmp" + PolicyRuleProtocolNetbirdSsh PolicyRuleProtocol = "netbird-ssh" + PolicyRuleProtocolTcp PolicyRuleProtocol = "tcp" + PolicyRuleProtocolUdp PolicyRuleProtocol = "udp" ) // Defines values for PolicyRuleMinimumAction. @@ -144,10 +156,11 @@ const ( // Defines values for PolicyRuleMinimumProtocol. const ( - PolicyRuleMinimumProtocolAll PolicyRuleMinimumProtocol = "all" - PolicyRuleMinimumProtocolIcmp PolicyRuleMinimumProtocol = "icmp" - PolicyRuleMinimumProtocolTcp PolicyRuleMinimumProtocol = "tcp" - PolicyRuleMinimumProtocolUdp PolicyRuleMinimumProtocol = "udp" + PolicyRuleMinimumProtocolAll PolicyRuleMinimumProtocol = "all" + PolicyRuleMinimumProtocolIcmp PolicyRuleMinimumProtocol = "icmp" + PolicyRuleMinimumProtocolNetbirdSsh PolicyRuleMinimumProtocol = "netbird-ssh" + PolicyRuleMinimumProtocolTcp PolicyRuleMinimumProtocol = "tcp" + PolicyRuleMinimumProtocolUdp PolicyRuleMinimumProtocol = "udp" ) // Defines values for PolicyRuleUpdateAction. @@ -158,10 +171,11 @@ const ( // Defines values for PolicyRuleUpdateProtocol. const ( - PolicyRuleUpdateProtocolAll PolicyRuleUpdateProtocol = "all" - PolicyRuleUpdateProtocolIcmp PolicyRuleUpdateProtocol = "icmp" - PolicyRuleUpdateProtocolTcp PolicyRuleUpdateProtocol = "tcp" - PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp" + PolicyRuleUpdateProtocolAll PolicyRuleUpdateProtocol = "all" + PolicyRuleUpdateProtocolIcmp PolicyRuleUpdateProtocol = "icmp" + PolicyRuleUpdateProtocolNetbirdSsh PolicyRuleUpdateProtocol = "netbird-ssh" + PolicyRuleUpdateProtocolTcp PolicyRuleUpdateProtocol = "tcp" + PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp" ) // Defines values for ResourceType. @@ -291,9 +305,15 @@ type AccountRequest struct { // AccountSettings defines model for AccountSettings. type AccountSettings struct { + // AutoUpdateVersion Set Clients auto-update version. "latest", "disabled", or a specific version (e.g "0.50.1") + AutoUpdateVersion *string `json:"auto_update_version,omitempty"` + // DnsDomain Allows to define a custom dns domain for the account - DnsDomain *string `json:"dns_domain,omitempty"` - Extra *AccountExtraSettings `json:"extra,omitempty"` + DnsDomain *string `json:"dns_domain,omitempty"` + + // EmbeddedIdpEnabled Indicates whether the embedded identity provider (Dex) is enabled for this account. This is a read-only field. + EmbeddedIdpEnabled *bool `json:"embedded_idp_enabled,omitempty"` + Extra *AccountExtraSettings `json:"extra,omitempty"` // GroupsPropagationEnabled Allows propagate the new user auto groups to peers that belongs to the user GroupsPropagationEnabled *bool `json:"groups_propagation_enabled,omitempty"` @@ -514,6 +534,45 @@ type GroupRequest struct { Resources *[]Resource `json:"resources,omitempty"` } +// IdentityProvider defines model for IdentityProvider. +type IdentityProvider struct { + // ClientId OAuth2 client ID + ClientId string `json:"client_id"` + + // Id Identity provider ID + Id *string `json:"id,omitempty"` + + // Issuer OIDC issuer URL + Issuer string `json:"issuer"` + + // Name Human-readable name for the identity provider + Name string `json:"name"` + + // Type Type of identity provider + Type IdentityProviderType `json:"type"` +} + +// IdentityProviderRequest defines model for IdentityProviderRequest. +type IdentityProviderRequest struct { + // ClientId OAuth2 client ID + ClientId string `json:"client_id"` + + // ClientSecret OAuth2 client secret + ClientSecret string `json:"client_secret"` + + // Issuer OIDC issuer URL + Issuer string `json:"issuer"` + + // Name Human-readable name for the identity provider + Name string `json:"name"` + + // Type Type of identity provider + Type IdentityProviderType `json:"type"` +} + +// IdentityProviderType Type of identity provider +type IdentityProviderType string + // IngressPeer defines model for IngressPeer. type IngressPeer struct { AvailablePorts AvailablePorts `json:"available_ports"` @@ -647,6 +706,12 @@ type IngressPortAllocationRequestPortRange struct { // IngressPortAllocationRequestPortRangeProtocol The protocol accepted by the port range type IngressPortAllocationRequestPortRangeProtocol string +// InstanceStatus Instance status information +type InstanceStatus struct { + // SetupRequired Indicates whether the instance requires initial setup + SetupRequired bool `json:"setup_required"` +} + // Location Describe geographical location information type Location struct { // CityName Commonly used English name of the city @@ -1074,7 +1139,8 @@ type Peer struct { LastLogin time.Time `json:"last_login"` // LastSeen Last time peer connected to Netbird's management service - LastSeen time.Time `json:"last_seen"` + LastSeen time.Time `json:"last_seen"` + LocalFlags *PeerLocalFlags `json:"local_flags,omitempty"` // LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not LoginExpirationEnabled bool `json:"login_expiration_enabled"` @@ -1164,7 +1230,8 @@ type PeerBatch struct { LastLogin time.Time `json:"last_login"` // LastSeen Last time peer connected to Netbird's management service - LastSeen time.Time `json:"last_seen"` + LastSeen time.Time `json:"last_seen"` + LocalFlags *PeerLocalFlags `json:"local_flags,omitempty"` // LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not LoginExpirationEnabled bool `json:"login_expiration_enabled"` @@ -1194,6 +1261,39 @@ type PeerBatch struct { Version string `json:"version"` } +// PeerLocalFlags defines model for PeerLocalFlags. +type PeerLocalFlags struct { + // BlockInbound Indicates whether inbound traffic is blocked on this peer + BlockInbound *bool `json:"block_inbound,omitempty"` + + // BlockLanAccess Indicates whether LAN access is blocked on this peer when used as a routing peer + BlockLanAccess *bool `json:"block_lan_access,omitempty"` + + // DisableClientRoutes Indicates whether client routes are disabled on this peer or not + DisableClientRoutes *bool `json:"disable_client_routes,omitempty"` + + // DisableDns Indicates whether DNS management is disabled on this peer or not + DisableDns *bool `json:"disable_dns,omitempty"` + + // DisableFirewall Indicates whether firewall management is disabled on this peer or not + DisableFirewall *bool `json:"disable_firewall,omitempty"` + + // DisableServerRoutes Indicates whether server routes are disabled on this peer or not + DisableServerRoutes *bool `json:"disable_server_routes,omitempty"` + + // LazyConnectionEnabled Indicates whether lazy connection is enabled on this peer + LazyConnectionEnabled *bool `json:"lazy_connection_enabled,omitempty"` + + // RosenpassEnabled Indicates whether Rosenpass is enabled on this peer + RosenpassEnabled *bool `json:"rosenpass_enabled,omitempty"` + + // RosenpassPermissive Indicates whether Rosenpass is in permissive mode or not + RosenpassPermissive *bool `json:"rosenpass_permissive,omitempty"` + + // ServerSshAllowed Indicates whether SSH access this peer is allowed or not + ServerSshAllowed *bool `json:"server_ssh_allowed,omitempty"` +} + // PeerMinimum defines model for PeerMinimum. type PeerMinimum struct { // Id Peer ID @@ -1346,6 +1446,9 @@ type PolicyRule struct { // Action Policy rule accept or drops packets Action PolicyRuleAction `json:"action"` + // AuthorizedGroups Map of user group ids to a list of local users + AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"` + // Bidirectional Define if the rule is applicable in both directions, sources, and destinations. Bidirectional bool `json:"bidirectional"` @@ -1390,6 +1493,9 @@ type PolicyRuleMinimum struct { // Action Policy rule accept or drops packets Action PolicyRuleMinimumAction `json:"action"` + // AuthorizedGroups Map of user group ids to a list of local users + AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"` + // Bidirectional Define if the rule is applicable in both directions, sources, and destinations. Bidirectional bool `json:"bidirectional"` @@ -1423,6 +1529,9 @@ type PolicyRuleUpdate struct { // Action Policy rule accept or drops packets Action PolicyRuleUpdateAction `json:"action"` + // AuthorizedGroups Map of user group ids to a list of local users + AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"` + // Bidirectional Define if the rule is applicable in both directions, sources, and destinations. Bidirectional bool `json:"bidirectional"` @@ -1783,6 +1892,27 @@ type SetupKeyRequest struct { Revoked bool `json:"revoked"` } +// SetupRequest Request to set up the initial admin user +type SetupRequest struct { + // Email Email address for the admin user + Email string `json:"email"` + + // Name Display name for the admin user (defaults to email if not provided) + Name string `json:"name"` + + // Password Password for the admin user (minimum 8 characters) + Password string `json:"password"` +} + +// SetupResponse Response after successful instance setup +type SetupResponse struct { + // Email Email address of the created user + Email string `json:"email"` + + // UserId The ID of the created user + UserId string `json:"user_id"` +} + // User defines model for User. type User struct { // AutoGroups Group IDs to auto-assign to peers registered by this user @@ -1794,6 +1924,9 @@ type User struct { // Id User ID Id string `json:"id"` + // IdpId Identity provider ID (connector ID) that the user authenticated with. Only populated for users with Dex-encoded user IDs. + IdpId *string `json:"idp_id,omitempty"` + // IsBlocked Is true if this user is blocked. Blocked users can't use the system IsBlocked bool `json:"is_blocked"` @@ -1812,6 +1945,9 @@ type User struct { // Name User's name from idp provider Name string `json:"name"` + // Password User's password. Only present when user is created (create user endpoint is called) and only when IdP supports user creation with password. + Password *string `json:"password,omitempty"` + // PendingApproval Is true if this user requires approval before being activated. Only applicable for users joining via domain matching when user_approval_required is enabled. PendingApproval bool `json:"pending_approval"` Permissions *UserPermissions `json:"permissions,omitempty"` @@ -1953,6 +2089,12 @@ type PostApiGroupsJSONRequestBody = GroupRequest // PutApiGroupsGroupIdJSONRequestBody defines body for PutApiGroupsGroupId for application/json ContentType. type PutApiGroupsGroupIdJSONRequestBody = GroupRequest +// PostApiIdentityProvidersJSONRequestBody defines body for PostApiIdentityProviders for application/json ContentType. +type PostApiIdentityProvidersJSONRequestBody = IdentityProviderRequest + +// PutApiIdentityProvidersIdpIdJSONRequestBody defines body for PutApiIdentityProvidersIdpId for application/json ContentType. +type PutApiIdentityProvidersIdpIdJSONRequestBody = IdentityProviderRequest + // PostApiIngressPeersJSONRequestBody defines body for PostApiIngressPeers for application/json ContentType. type PostApiIngressPeersJSONRequestBody = IngressPeerCreateRequest @@ -2007,6 +2149,9 @@ type PostApiRoutesJSONRequestBody = RouteRequest // PutApiRoutesRouteIdJSONRequestBody defines body for PutApiRoutesRouteId for application/json ContentType. type PutApiRoutesRouteIdJSONRequestBody = RouteRequest +// PostApiSetupJSONRequestBody defines body for PostApiSetup for application/json ContentType. +type PostApiSetupJSONRequestBody = SetupRequest + // PostApiSetupKeysJSONRequestBody defines body for PostApiSetupKeys for application/json ContentType. type PostApiSetupKeysJSONRequestBody = CreateSetupKeyRequest diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index 2e4cf2644..2047c51ea 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v6.32.1 +// protoc v6.33.1 // source: management.proto package proto @@ -267,7 +267,7 @@ func (x DeviceAuthorizationFlowProvider) Number() protoreflect.EnumNumber { // Deprecated: Use DeviceAuthorizationFlowProvider.Descriptor instead. func (DeviceAuthorizationFlowProvider) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{24, 0} + return file_management_proto_rawDescGZIP(), []int{27, 0} } type EncryptedMessage struct { @@ -1762,6 +1762,8 @@ type PeerConfig struct { RoutingPeerDnsResolutionEnabled bool `protobuf:"varint,5,opt,name=RoutingPeerDnsResolutionEnabled,proto3" json:"RoutingPeerDnsResolutionEnabled,omitempty"` LazyConnectionEnabled bool `protobuf:"varint,6,opt,name=LazyConnectionEnabled,proto3" json:"LazyConnectionEnabled,omitempty"` Mtu int32 `protobuf:"varint,7,opt,name=mtu,proto3" json:"mtu,omitempty"` + // Auto-update config + AutoUpdate *AutoUpdateSettings `protobuf:"bytes,8,opt,name=autoUpdate,proto3" json:"autoUpdate,omitempty"` } func (x *PeerConfig) Reset() { @@ -1845,6 +1847,70 @@ func (x *PeerConfig) GetMtu() int32 { return 0 } +func (x *PeerConfig) GetAutoUpdate() *AutoUpdateSettings { + if x != nil { + return x.AutoUpdate + } + return nil +} + +type AutoUpdateSettings struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Version string `protobuf:"bytes,1,opt,name=version,proto3" json:"version,omitempty"` + // alwaysUpdate = true → Updates happen automatically in the background + // alwaysUpdate = false → Updates only happen when triggered by a peer connection + AlwaysUpdate bool `protobuf:"varint,2,opt,name=alwaysUpdate,proto3" json:"alwaysUpdate,omitempty"` +} + +func (x *AutoUpdateSettings) Reset() { + *x = AutoUpdateSettings{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[20] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *AutoUpdateSettings) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AutoUpdateSettings) ProtoMessage() {} + +func (x *AutoUpdateSettings) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[20] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AutoUpdateSettings.ProtoReflect.Descriptor instead. +func (*AutoUpdateSettings) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{20} +} + +func (x *AutoUpdateSettings) GetVersion() string { + if x != nil { + return x.Version + } + return "" +} + +func (x *AutoUpdateSettings) GetAlwaysUpdate() bool { + if x != nil { + return x.AlwaysUpdate + } + return false +} + // NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections type NetworkMap struct { state protoimpl.MessageState @@ -1876,12 +1942,14 @@ type NetworkMap struct { // RoutesFirewallRulesIsEmpty indicates whether RouteFirewallRule array is empty or not to bypass protobuf null and empty array equality. RoutesFirewallRulesIsEmpty bool `protobuf:"varint,11,opt,name=routesFirewallRulesIsEmpty,proto3" json:"routesFirewallRulesIsEmpty,omitempty"` ForwardingRules []*ForwardingRule `protobuf:"bytes,12,rep,name=forwardingRules,proto3" json:"forwardingRules,omitempty"` + // SSHAuth represents SSH authorization configuration + SshAuth *SSHAuth `protobuf:"bytes,13,opt,name=sshAuth,proto3" json:"sshAuth,omitempty"` } func (x *NetworkMap) Reset() { *x = NetworkMap{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[20] + mi := &file_management_proto_msgTypes[21] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1894,7 +1962,7 @@ func (x *NetworkMap) String() string { func (*NetworkMap) ProtoMessage() {} func (x *NetworkMap) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[20] + mi := &file_management_proto_msgTypes[21] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1907,7 +1975,7 @@ func (x *NetworkMap) ProtoReflect() protoreflect.Message { // Deprecated: Use NetworkMap.ProtoReflect.Descriptor instead. func (*NetworkMap) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{20} + return file_management_proto_rawDescGZIP(), []int{21} } func (x *NetworkMap) GetSerial() uint64 { @@ -1994,6 +2062,126 @@ func (x *NetworkMap) GetForwardingRules() []*ForwardingRule { return nil } +func (x *NetworkMap) GetSshAuth() *SSHAuth { + if x != nil { + return x.SshAuth + } + return nil +} + +type SSHAuth struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // UserIDClaim is the JWT claim to be used to get the users ID + UserIDClaim string `protobuf:"bytes,1,opt,name=UserIDClaim,proto3" json:"UserIDClaim,omitempty"` + // AuthorizedUsers is a list of hashed user IDs authorized to access this peer via SSH + AuthorizedUsers [][]byte `protobuf:"bytes,2,rep,name=AuthorizedUsers,proto3" json:"AuthorizedUsers,omitempty"` + // MachineUsers is a map of machine user names to their corresponding indexes in the AuthorizedUsers list + MachineUsers map[string]*MachineUserIndexes `protobuf:"bytes,3,rep,name=machine_users,json=machineUsers,proto3" json:"machine_users,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` +} + +func (x *SSHAuth) Reset() { + *x = SSHAuth{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[22] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SSHAuth) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SSHAuth) ProtoMessage() {} + +func (x *SSHAuth) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[22] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SSHAuth.ProtoReflect.Descriptor instead. +func (*SSHAuth) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{22} +} + +func (x *SSHAuth) GetUserIDClaim() string { + if x != nil { + return x.UserIDClaim + } + return "" +} + +func (x *SSHAuth) GetAuthorizedUsers() [][]byte { + if x != nil { + return x.AuthorizedUsers + } + return nil +} + +func (x *SSHAuth) GetMachineUsers() map[string]*MachineUserIndexes { + if x != nil { + return x.MachineUsers + } + return nil +} + +type MachineUserIndexes struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Indexes []uint32 `protobuf:"varint,1,rep,packed,name=indexes,proto3" json:"indexes,omitempty"` +} + +func (x *MachineUserIndexes) Reset() { + *x = MachineUserIndexes{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[23] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MachineUserIndexes) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MachineUserIndexes) ProtoMessage() {} + +func (x *MachineUserIndexes) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[23] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MachineUserIndexes.ProtoReflect.Descriptor instead. +func (*MachineUserIndexes) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{23} +} + +func (x *MachineUserIndexes) GetIndexes() []uint32 { + if x != nil { + return x.Indexes + } + return nil +} + // RemotePeerConfig represents a configuration of a remote peer. // The properties are used to configure WireGuard Peers sections type RemotePeerConfig struct { @@ -2015,7 +2203,7 @@ type RemotePeerConfig struct { func (x *RemotePeerConfig) Reset() { *x = RemotePeerConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[21] + mi := &file_management_proto_msgTypes[24] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2028,7 +2216,7 @@ func (x *RemotePeerConfig) String() string { func (*RemotePeerConfig) ProtoMessage() {} func (x *RemotePeerConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[21] + mi := &file_management_proto_msgTypes[24] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2041,7 +2229,7 @@ func (x *RemotePeerConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use RemotePeerConfig.ProtoReflect.Descriptor instead. func (*RemotePeerConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{21} + return file_management_proto_rawDescGZIP(), []int{24} } func (x *RemotePeerConfig) GetWgPubKey() string { @@ -2096,7 +2284,7 @@ type SSHConfig struct { func (x *SSHConfig) Reset() { *x = SSHConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[22] + mi := &file_management_proto_msgTypes[25] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2109,7 +2297,7 @@ func (x *SSHConfig) String() string { func (*SSHConfig) ProtoMessage() {} func (x *SSHConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[22] + mi := &file_management_proto_msgTypes[25] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2122,7 +2310,7 @@ func (x *SSHConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use SSHConfig.ProtoReflect.Descriptor instead. func (*SSHConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{22} + return file_management_proto_rawDescGZIP(), []int{25} } func (x *SSHConfig) GetSshEnabled() bool { @@ -2156,7 +2344,7 @@ type DeviceAuthorizationFlowRequest struct { func (x *DeviceAuthorizationFlowRequest) Reset() { *x = DeviceAuthorizationFlowRequest{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[23] + mi := &file_management_proto_msgTypes[26] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2169,7 +2357,7 @@ func (x *DeviceAuthorizationFlowRequest) String() string { func (*DeviceAuthorizationFlowRequest) ProtoMessage() {} func (x *DeviceAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[23] + mi := &file_management_proto_msgTypes[26] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2182,7 +2370,7 @@ func (x *DeviceAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DeviceAuthorizationFlowRequest.ProtoReflect.Descriptor instead. func (*DeviceAuthorizationFlowRequest) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{23} + return file_management_proto_rawDescGZIP(), []int{26} } // DeviceAuthorizationFlow represents Device Authorization Flow information @@ -2201,7 +2389,7 @@ type DeviceAuthorizationFlow struct { func (x *DeviceAuthorizationFlow) Reset() { *x = DeviceAuthorizationFlow{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[24] + mi := &file_management_proto_msgTypes[27] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2214,7 +2402,7 @@ func (x *DeviceAuthorizationFlow) String() string { func (*DeviceAuthorizationFlow) ProtoMessage() {} func (x *DeviceAuthorizationFlow) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[24] + mi := &file_management_proto_msgTypes[27] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2227,7 +2415,7 @@ func (x *DeviceAuthorizationFlow) ProtoReflect() protoreflect.Message { // Deprecated: Use DeviceAuthorizationFlow.ProtoReflect.Descriptor instead. func (*DeviceAuthorizationFlow) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{24} + return file_management_proto_rawDescGZIP(), []int{27} } func (x *DeviceAuthorizationFlow) GetProvider() DeviceAuthorizationFlowProvider { @@ -2254,7 +2442,7 @@ type PKCEAuthorizationFlowRequest struct { func (x *PKCEAuthorizationFlowRequest) Reset() { *x = PKCEAuthorizationFlowRequest{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[25] + mi := &file_management_proto_msgTypes[28] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2267,7 +2455,7 @@ func (x *PKCEAuthorizationFlowRequest) String() string { func (*PKCEAuthorizationFlowRequest) ProtoMessage() {} func (x *PKCEAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[25] + mi := &file_management_proto_msgTypes[28] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2280,7 +2468,7 @@ func (x *PKCEAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use PKCEAuthorizationFlowRequest.ProtoReflect.Descriptor instead. func (*PKCEAuthorizationFlowRequest) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{25} + return file_management_proto_rawDescGZIP(), []int{28} } // PKCEAuthorizationFlow represents Authorization Code Flow information @@ -2297,7 +2485,7 @@ type PKCEAuthorizationFlow struct { func (x *PKCEAuthorizationFlow) Reset() { *x = PKCEAuthorizationFlow{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[26] + mi := &file_management_proto_msgTypes[29] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2310,7 +2498,7 @@ func (x *PKCEAuthorizationFlow) String() string { func (*PKCEAuthorizationFlow) ProtoMessage() {} func (x *PKCEAuthorizationFlow) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[26] + mi := &file_management_proto_msgTypes[29] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2323,7 +2511,7 @@ func (x *PKCEAuthorizationFlow) ProtoReflect() protoreflect.Message { // Deprecated: Use PKCEAuthorizationFlow.ProtoReflect.Descriptor instead. func (*PKCEAuthorizationFlow) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{26} + return file_management_proto_rawDescGZIP(), []int{29} } func (x *PKCEAuthorizationFlow) GetProviderConfig() *ProviderConfig { @@ -2369,7 +2557,7 @@ type ProviderConfig struct { func (x *ProviderConfig) Reset() { *x = ProviderConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[27] + mi := &file_management_proto_msgTypes[30] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2382,7 +2570,7 @@ func (x *ProviderConfig) String() string { func (*ProviderConfig) ProtoMessage() {} func (x *ProviderConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[27] + mi := &file_management_proto_msgTypes[30] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2395,7 +2583,7 @@ func (x *ProviderConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use ProviderConfig.ProtoReflect.Descriptor instead. func (*ProviderConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{27} + return file_management_proto_rawDescGZIP(), []int{30} } func (x *ProviderConfig) GetClientID() string { @@ -2503,7 +2691,7 @@ type Route struct { func (x *Route) Reset() { *x = Route{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[28] + mi := &file_management_proto_msgTypes[31] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2516,7 +2704,7 @@ func (x *Route) String() string { func (*Route) ProtoMessage() {} func (x *Route) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[28] + mi := &file_management_proto_msgTypes[31] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2529,7 +2717,7 @@ func (x *Route) ProtoReflect() protoreflect.Message { // Deprecated: Use Route.ProtoReflect.Descriptor instead. func (*Route) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{28} + return file_management_proto_rawDescGZIP(), []int{31} } func (x *Route) GetID() string { @@ -2618,7 +2806,7 @@ type DNSConfig struct { func (x *DNSConfig) Reset() { *x = DNSConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[29] + mi := &file_management_proto_msgTypes[32] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2631,7 +2819,7 @@ func (x *DNSConfig) String() string { func (*DNSConfig) ProtoMessage() {} func (x *DNSConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[29] + mi := &file_management_proto_msgTypes[32] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2644,7 +2832,7 @@ func (x *DNSConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use DNSConfig.ProtoReflect.Descriptor instead. func (*DNSConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{29} + return file_management_proto_rawDescGZIP(), []int{32} } func (x *DNSConfig) GetServiceEnable() bool { @@ -2691,7 +2879,7 @@ type CustomZone struct { func (x *CustomZone) Reset() { *x = CustomZone{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[30] + mi := &file_management_proto_msgTypes[33] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2704,7 +2892,7 @@ func (x *CustomZone) String() string { func (*CustomZone) ProtoMessage() {} func (x *CustomZone) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[30] + mi := &file_management_proto_msgTypes[33] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2717,7 +2905,7 @@ func (x *CustomZone) ProtoReflect() protoreflect.Message { // Deprecated: Use CustomZone.ProtoReflect.Descriptor instead. func (*CustomZone) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{30} + return file_management_proto_rawDescGZIP(), []int{33} } func (x *CustomZone) GetDomain() string { @@ -2764,7 +2952,7 @@ type SimpleRecord struct { func (x *SimpleRecord) Reset() { *x = SimpleRecord{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[31] + mi := &file_management_proto_msgTypes[34] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2777,7 +2965,7 @@ func (x *SimpleRecord) String() string { func (*SimpleRecord) ProtoMessage() {} func (x *SimpleRecord) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[31] + mi := &file_management_proto_msgTypes[34] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2790,7 +2978,7 @@ func (x *SimpleRecord) ProtoReflect() protoreflect.Message { // Deprecated: Use SimpleRecord.ProtoReflect.Descriptor instead. func (*SimpleRecord) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{31} + return file_management_proto_rawDescGZIP(), []int{34} } func (x *SimpleRecord) GetName() string { @@ -2843,7 +3031,7 @@ type NameServerGroup struct { func (x *NameServerGroup) Reset() { *x = NameServerGroup{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[32] + mi := &file_management_proto_msgTypes[35] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2856,7 +3044,7 @@ func (x *NameServerGroup) String() string { func (*NameServerGroup) ProtoMessage() {} func (x *NameServerGroup) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[32] + mi := &file_management_proto_msgTypes[35] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2869,7 +3057,7 @@ func (x *NameServerGroup) ProtoReflect() protoreflect.Message { // Deprecated: Use NameServerGroup.ProtoReflect.Descriptor instead. func (*NameServerGroup) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{32} + return file_management_proto_rawDescGZIP(), []int{35} } func (x *NameServerGroup) GetNameServers() []*NameServer { @@ -2914,7 +3102,7 @@ type NameServer struct { func (x *NameServer) Reset() { *x = NameServer{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[33] + mi := &file_management_proto_msgTypes[36] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2927,7 +3115,7 @@ func (x *NameServer) String() string { func (*NameServer) ProtoMessage() {} func (x *NameServer) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[33] + mi := &file_management_proto_msgTypes[36] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2940,7 +3128,7 @@ func (x *NameServer) ProtoReflect() protoreflect.Message { // Deprecated: Use NameServer.ProtoReflect.Descriptor instead. func (*NameServer) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{33} + return file_management_proto_rawDescGZIP(), []int{36} } func (x *NameServer) GetIP() string { @@ -2983,7 +3171,7 @@ type FirewallRule struct { func (x *FirewallRule) Reset() { *x = FirewallRule{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[34] + mi := &file_management_proto_msgTypes[37] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2996,7 +3184,7 @@ func (x *FirewallRule) String() string { func (*FirewallRule) ProtoMessage() {} func (x *FirewallRule) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[34] + mi := &file_management_proto_msgTypes[37] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3009,7 +3197,7 @@ func (x *FirewallRule) ProtoReflect() protoreflect.Message { // Deprecated: Use FirewallRule.ProtoReflect.Descriptor instead. func (*FirewallRule) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{34} + return file_management_proto_rawDescGZIP(), []int{37} } func (x *FirewallRule) GetPeerIP() string { @@ -3073,7 +3261,7 @@ type NetworkAddress struct { func (x *NetworkAddress) Reset() { *x = NetworkAddress{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[35] + mi := &file_management_proto_msgTypes[38] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3086,7 +3274,7 @@ func (x *NetworkAddress) String() string { func (*NetworkAddress) ProtoMessage() {} func (x *NetworkAddress) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[35] + mi := &file_management_proto_msgTypes[38] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3099,7 +3287,7 @@ func (x *NetworkAddress) ProtoReflect() protoreflect.Message { // Deprecated: Use NetworkAddress.ProtoReflect.Descriptor instead. func (*NetworkAddress) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{35} + return file_management_proto_rawDescGZIP(), []int{38} } func (x *NetworkAddress) GetNetIP() string { @@ -3127,7 +3315,7 @@ type Checks struct { func (x *Checks) Reset() { *x = Checks{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[36] + mi := &file_management_proto_msgTypes[39] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3140,7 +3328,7 @@ func (x *Checks) String() string { func (*Checks) ProtoMessage() {} func (x *Checks) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[36] + mi := &file_management_proto_msgTypes[39] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3153,7 +3341,7 @@ func (x *Checks) ProtoReflect() protoreflect.Message { // Deprecated: Use Checks.ProtoReflect.Descriptor instead. func (*Checks) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{36} + return file_management_proto_rawDescGZIP(), []int{39} } func (x *Checks) GetFiles() []string { @@ -3178,7 +3366,7 @@ type PortInfo struct { func (x *PortInfo) Reset() { *x = PortInfo{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[37] + mi := &file_management_proto_msgTypes[40] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3191,7 +3379,7 @@ func (x *PortInfo) String() string { func (*PortInfo) ProtoMessage() {} func (x *PortInfo) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[37] + mi := &file_management_proto_msgTypes[40] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3204,7 +3392,7 @@ func (x *PortInfo) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo.ProtoReflect.Descriptor instead. func (*PortInfo) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{37} + return file_management_proto_rawDescGZIP(), []int{40} } func (m *PortInfo) GetPortSelection() isPortInfo_PortSelection { @@ -3275,7 +3463,7 @@ type RouteFirewallRule struct { func (x *RouteFirewallRule) Reset() { *x = RouteFirewallRule{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[38] + mi := &file_management_proto_msgTypes[41] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3288,7 +3476,7 @@ func (x *RouteFirewallRule) String() string { func (*RouteFirewallRule) ProtoMessage() {} func (x *RouteFirewallRule) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[38] + mi := &file_management_proto_msgTypes[41] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3301,7 +3489,7 @@ func (x *RouteFirewallRule) ProtoReflect() protoreflect.Message { // Deprecated: Use RouteFirewallRule.ProtoReflect.Descriptor instead. func (*RouteFirewallRule) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{38} + return file_management_proto_rawDescGZIP(), []int{41} } func (x *RouteFirewallRule) GetSourceRanges() []string { @@ -3392,7 +3580,7 @@ type ForwardingRule struct { func (x *ForwardingRule) Reset() { *x = ForwardingRule{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[39] + mi := &file_management_proto_msgTypes[42] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3405,7 +3593,7 @@ func (x *ForwardingRule) String() string { func (*ForwardingRule) ProtoMessage() {} func (x *ForwardingRule) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[39] + mi := &file_management_proto_msgTypes[42] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3418,7 +3606,7 @@ func (x *ForwardingRule) ProtoReflect() protoreflect.Message { // Deprecated: Use ForwardingRule.ProtoReflect.Descriptor instead. func (*ForwardingRule) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{39} + return file_management_proto_rawDescGZIP(), []int{42} } func (x *ForwardingRule) GetProtocol() RuleProtocol { @@ -3461,7 +3649,7 @@ type PortInfo_Range struct { func (x *PortInfo_Range) Reset() { *x = PortInfo_Range{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[40] + mi := &file_management_proto_msgTypes[44] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3474,7 +3662,7 @@ func (x *PortInfo_Range) String() string { func (*PortInfo_Range) ProtoMessage() {} func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[40] + mi := &file_management_proto_msgTypes[44] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3487,7 +3675,7 @@ func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo_Range.ProtoReflect.Descriptor instead. func (*PortInfo_Range) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{37, 0} + return file_management_proto_rawDescGZIP(), []int{40, 0} } func (x *PortInfo_Range) GetStart() uint32 { @@ -3747,7 +3935,7 @@ var file_management_proto_rawDesc = []byte{ 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, - 0x22, 0x93, 0x02, 0x0a, 0x0a, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, + 0x22, 0xd3, 0x02, 0x0a, 0x0a, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x64, 0x6e, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, @@ -3764,308 +3952,339 @@ var file_management_proto_rawDesc = []byte{ 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x15, 0x4c, 0x61, 0x7a, 0x79, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x74, 0x75, 0x18, 0x07, 0x20, 0x01, 0x28, - 0x05, 0x52, 0x03, 0x6d, 0x74, 0x75, 0x22, 0xb9, 0x05, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, - 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, - 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, - 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3e, 0x0a, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, - 0x65, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, - 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, - 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, - 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, - 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x29, 0x0a, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, - 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, - 0x12, 0x33, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x06, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x44, 0x4e, 0x53, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x40, 0x0a, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, - 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, - 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, - 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x3e, 0x0a, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, - 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, - 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, - 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, - 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, - 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, - 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x4f, 0x0a, 0x13, 0x72, - 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, - 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, - 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, - 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x3e, 0x0a, 0x1a, - 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, - 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, - 0x52, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, - 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x44, 0x0a, 0x0f, - 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, - 0x0c, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, - 0x65, 0x52, 0x0f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, - 0x65, 0x73, 0x22, 0xbb, 0x01, 0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, - 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, - 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, - 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, - 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, - 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x05, 0x52, 0x03, 0x6d, 0x74, 0x75, 0x12, 0x3e, 0x0a, 0x0a, 0x61, 0x75, 0x74, 0x6f, 0x55, 0x70, + 0x64, 0x61, 0x74, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x6f, 0x55, 0x70, 0x64, 0x61, + 0x74, 0x65, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x52, 0x0a, 0x61, 0x75, 0x74, 0x6f, + 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x22, 0x52, 0x0a, 0x12, 0x41, 0x75, 0x74, 0x6f, 0x55, 0x70, + 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x12, 0x18, 0x0a, 0x07, + 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, + 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x22, 0x0a, 0x0c, 0x61, 0x6c, 0x77, 0x61, 0x79, 0x73, + 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0c, 0x61, 0x6c, + 0x77, 0x61, 0x79, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x22, 0xe8, 0x05, 0x0a, 0x0a, 0x4e, + 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, + 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, + 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, + 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3e, 0x0a, 0x0b, 0x72, 0x65, 0x6d, + 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, + 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x72, 0x65, + 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6d, + 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, + 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x29, 0x0a, 0x06, 0x52, 0x6f, 0x75, + 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x52, 0x6f, + 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, + 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x40, 0x0a, 0x0c, 0x6f, 0x66, 0x66, + 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, + 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0c, 0x6f, + 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x3e, 0x0a, 0x0d, 0x46, + 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x08, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0d, 0x46, 0x69, + 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x66, + 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, + 0x70, 0x74, 0x79, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, + 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, + 0x4f, 0x0a, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, + 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, + 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x13, 0x72, 0x6f, 0x75, + 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, + 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, + 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x0b, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, + 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, + 0x12, 0x44, 0x0a, 0x0f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, + 0x6c, 0x65, 0x73, 0x18, 0x0c, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, + 0x67, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, + 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x2d, 0x0a, 0x07, 0x73, 0x73, 0x68, 0x41, 0x75, 0x74, + 0x68, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x52, 0x07, 0x73, 0x73, + 0x68, 0x41, 0x75, 0x74, 0x68, 0x22, 0x82, 0x02, 0x0a, 0x07, 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, + 0x68, 0x12, 0x20, 0x0a, 0x0b, 0x55, 0x73, 0x65, 0x72, 0x49, 0x44, 0x43, 0x6c, 0x61, 0x69, 0x6d, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x55, 0x73, 0x65, 0x72, 0x49, 0x44, 0x43, 0x6c, + 0x61, 0x69, 0x6d, 0x12, 0x28, 0x0a, 0x0f, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, + 0x64, 0x55, 0x73, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0c, 0x52, 0x0f, 0x41, 0x75, + 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x55, 0x73, 0x65, 0x72, 0x73, 0x12, 0x4a, 0x0a, + 0x0d, 0x6d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x5f, 0x75, 0x73, 0x65, 0x72, 0x73, 0x18, 0x03, + 0x20, 0x03, 0x28, 0x0b, 0x32, 0x25, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x2e, 0x4d, 0x61, 0x63, 0x68, 0x69, 0x6e, + 0x65, 0x55, 0x73, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0c, 0x6d, 0x61, 0x63, + 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, 0x65, 0x72, 0x73, 0x1a, 0x5f, 0x0a, 0x11, 0x4d, 0x61, 0x63, + 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, + 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, + 0x12, 0x34, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x1e, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4d, 0x61, 0x63, + 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, 0x65, 0x72, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x65, 0x73, 0x52, + 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x2e, 0x0a, 0x12, 0x4d, 0x61, + 0x63, 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, 0x65, 0x72, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x65, 0x73, + 0x12, 0x18, 0x0a, 0x07, 0x69, 0x6e, 0x64, 0x65, 0x78, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x0d, 0x52, 0x07, 0x69, 0x6e, 0x64, 0x65, 0x78, 0x65, 0x73, 0x22, 0xbb, 0x01, 0x0a, 0x10, 0x52, + 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, + 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, + 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, + 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, + 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x66, 0x71, 0x64, 0x6e, 0x12, 0x22, 0x0a, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x61, 0x67, 0x65, 0x6e, + 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x7e, 0x0a, 0x09, 0x53, 0x53, 0x48, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, + 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, + 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, + 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, + 0x4b, 0x65, 0x79, 0x12, 0x33, 0x0a, 0x09, 0x6a, 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, - 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, - 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x22, 0x0a, 0x0c, - 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x05, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, - 0x22, 0x7e, 0x0a, 0x09, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, - 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, - 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, - 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x33, 0x0a, 0x09, 0x6a, - 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4a, 0x57, 0x54, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x6a, 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, - 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, - 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x48, - 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x65, - 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, - 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x52, 0x08, - 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, - 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, - 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, - 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, - 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, - 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, - 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, - 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x42, 0x0a, - 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x22, 0xb8, 0x03, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, - 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, - 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, - 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, - 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, - 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, - 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, - 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x14, - 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x53, - 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, - 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, - 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, - 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x09, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x22, 0x0a, 0x0c, 0x52, 0x65, - 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, - 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x12, 0x2e, - 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, - 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x44, 0x69, 0x73, 0x61, - 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, - 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, 0x67, 0x18, 0x0c, 0x20, 0x01, 0x28, - 0x0d, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, 0x67, 0x22, 0x93, 0x02, 0x0a, - 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, - 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, - 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, - 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, - 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, - 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, - 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, - 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, - 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x1c, - 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x24, 0x0a, 0x0d, - 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x18, 0x0a, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, - 0x6c, 0x79, 0x22, 0xde, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, - 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, - 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, - 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, - 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, - 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, 0x4e, - 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, - 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, - 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, 0x75, - 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x12, 0x28, 0x0a, 0x0d, 0x46, 0x6f, 0x72, - 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, - 0x42, 0x02, 0x18, 0x01, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, - 0x6f, 0x72, 0x74, 0x22, 0xb4, 0x01, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, - 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, - 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, - 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x12, 0x32, - 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x44, 0x69, - 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, - 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, - 0x65, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x53, 0x6b, 0x69, 0x70, 0x50, 0x54, 0x52, 0x50, 0x72, 0x6f, - 0x63, 0x65, 0x73, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0e, 0x53, 0x6b, 0x69, 0x70, - 0x50, 0x54, 0x52, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, - 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, - 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, - 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, - 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, - 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, - 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, - 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, - 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, - 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, - 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, - 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, - 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, - 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, - 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, - 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, - 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, - 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, - 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, - 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, - 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, + 0x65, 0x6e, 0x74, 0x2e, 0x4a, 0x57, 0x54, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x6a, + 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, + 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, + 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, + 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, + 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, + 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, + 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, + 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, + 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, + 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, + 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, + 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, + 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, + 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xb8, 0x03, 0x0a, 0x0e, 0x50, 0x72, + 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, + 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, + 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, + 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, + 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, + 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, + 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, + 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, + 0x12, 0x24, 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, + 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, + 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, + 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, + 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, + 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, + 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, + 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, + 0x6e, 0x74, 0x12, 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, + 0x4c, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, + 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, + 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, + 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, + 0x6c, 0x61, 0x67, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, + 0x46, 0x6c, 0x61, 0x67, 0x22, 0x93, 0x02, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, + 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, + 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, + 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, + 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, + 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, + 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, + 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, + 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, + 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, + 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, + 0x75, 0x74, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, + 0x6f, 0x75, 0x74, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, + 0x41, 0x70, 0x70, 0x6c, 0x79, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, + 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x22, 0xde, 0x01, 0x0a, 0x09, 0x44, + 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, + 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, + 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, + 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, + 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, + 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, + 0x73, 0x12, 0x28, 0x0a, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, + 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0d, 0x46, 0x6f, + 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xb4, 0x01, 0x0a, 0x0a, + 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, + 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x53, 0x6b, + 0x69, 0x70, 0x50, 0x54, 0x52, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x0e, 0x53, 0x6b, 0x69, 0x70, 0x50, 0x54, 0x52, 0x50, 0x72, 0x6f, 0x63, 0x65, + 0x73, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, + 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, + 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, + 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, + 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, + 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, + 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, + 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, + 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, + 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, + 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, + 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, + 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, + 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, + 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, + 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, + 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, + 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, + 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, + 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, + 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, + 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x50, 0x6f, 0x72, 0x74, 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, + 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, + 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, + 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, + 0x49, 0x44, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, + 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, + 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, + 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, + 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, + 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, + 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, + 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, + 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, + 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, + 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, + 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, + 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, + 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, + 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, + 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, + 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, + 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, + 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, + 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, + 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, + 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, + 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x73, 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, + 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, + 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, + 0x69, 0x63, 0x79, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, + 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, + 0xf2, 0x01, 0x0a, 0x0e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, + 0x6c, 0x65, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, - 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, - 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x30, 0x0a, 0x08, - 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, - 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1a, - 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, - 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, - 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, - 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, - 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, - 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, - 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, - 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, - 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, - 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, - 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, - 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, - 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x87, 0x03, - 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, - 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, - 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, - 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, - 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, - 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, - 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, - 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, - 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, - 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, + 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, - 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, - 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x12, - 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, - 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, - 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, - 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, - 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x09, 0x20, - 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x12, 0x18, 0x0a, - 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, - 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, 0x0e, 0x46, 0x6f, 0x72, 0x77, - 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, - 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, - 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, - 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, - 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, - 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, - 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, 0x74, 0x72, 0x61, - 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x3c, - 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, - 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, 0x74, 0x72, - 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x2a, 0x4c, 0x0a, 0x0c, - 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, - 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, - 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, - 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x12, 0x0a, - 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, 0x52, 0x75, - 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, - 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, - 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, - 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, - 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, - 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, - 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, - 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, - 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, - 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, + 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x0c, 0x52, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, + 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, + 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, + 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, + 0x50, 0x6f, 0x72, 0x74, 0x2a, 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, + 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, + 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, + 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, + 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, + 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, + 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, + 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, + 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, + 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, - 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, - 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, - 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, - 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, - 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, - 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, - 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, - 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, - 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, - 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, - 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, - 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67, 0x6f, 0x75, 0x74, 0x12, 0x1c, 0x2e, + 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, - 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, - 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x33, + 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, + 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, + 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, + 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, + 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, + 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, + 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, + 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, + 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, + 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, + 0x6f, 0x67, 0x6f, 0x75, 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -4081,7 +4300,7 @@ func file_management_proto_rawDescGZIP() []byte { } var file_management_proto_enumTypes = make([]protoimpl.EnumInfo, 5) -var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 41) +var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 45) var file_management_proto_goTypes = []interface{}{ (RuleProtocol)(0), // 0: management.RuleProtocol (RuleDirection)(0), // 1: management.RuleDirection @@ -4108,106 +4327,114 @@ var file_management_proto_goTypes = []interface{}{ (*JWTConfig)(nil), // 22: management.JWTConfig (*ProtectedHostConfig)(nil), // 23: management.ProtectedHostConfig (*PeerConfig)(nil), // 24: management.PeerConfig - (*NetworkMap)(nil), // 25: management.NetworkMap - (*RemotePeerConfig)(nil), // 26: management.RemotePeerConfig - (*SSHConfig)(nil), // 27: management.SSHConfig - (*DeviceAuthorizationFlowRequest)(nil), // 28: management.DeviceAuthorizationFlowRequest - (*DeviceAuthorizationFlow)(nil), // 29: management.DeviceAuthorizationFlow - (*PKCEAuthorizationFlowRequest)(nil), // 30: management.PKCEAuthorizationFlowRequest - (*PKCEAuthorizationFlow)(nil), // 31: management.PKCEAuthorizationFlow - (*ProviderConfig)(nil), // 32: management.ProviderConfig - (*Route)(nil), // 33: management.Route - (*DNSConfig)(nil), // 34: management.DNSConfig - (*CustomZone)(nil), // 35: management.CustomZone - (*SimpleRecord)(nil), // 36: management.SimpleRecord - (*NameServerGroup)(nil), // 37: management.NameServerGroup - (*NameServer)(nil), // 38: management.NameServer - (*FirewallRule)(nil), // 39: management.FirewallRule - (*NetworkAddress)(nil), // 40: management.NetworkAddress - (*Checks)(nil), // 41: management.Checks - (*PortInfo)(nil), // 42: management.PortInfo - (*RouteFirewallRule)(nil), // 43: management.RouteFirewallRule - (*ForwardingRule)(nil), // 44: management.ForwardingRule - (*PortInfo_Range)(nil), // 45: management.PortInfo.Range - (*timestamppb.Timestamp)(nil), // 46: google.protobuf.Timestamp - (*durationpb.Duration)(nil), // 47: google.protobuf.Duration + (*AutoUpdateSettings)(nil), // 25: management.AutoUpdateSettings + (*NetworkMap)(nil), // 26: management.NetworkMap + (*SSHAuth)(nil), // 27: management.SSHAuth + (*MachineUserIndexes)(nil), // 28: management.MachineUserIndexes + (*RemotePeerConfig)(nil), // 29: management.RemotePeerConfig + (*SSHConfig)(nil), // 30: management.SSHConfig + (*DeviceAuthorizationFlowRequest)(nil), // 31: management.DeviceAuthorizationFlowRequest + (*DeviceAuthorizationFlow)(nil), // 32: management.DeviceAuthorizationFlow + (*PKCEAuthorizationFlowRequest)(nil), // 33: management.PKCEAuthorizationFlowRequest + (*PKCEAuthorizationFlow)(nil), // 34: management.PKCEAuthorizationFlow + (*ProviderConfig)(nil), // 35: management.ProviderConfig + (*Route)(nil), // 36: management.Route + (*DNSConfig)(nil), // 37: management.DNSConfig + (*CustomZone)(nil), // 38: management.CustomZone + (*SimpleRecord)(nil), // 39: management.SimpleRecord + (*NameServerGroup)(nil), // 40: management.NameServerGroup + (*NameServer)(nil), // 41: management.NameServer + (*FirewallRule)(nil), // 42: management.FirewallRule + (*NetworkAddress)(nil), // 43: management.NetworkAddress + (*Checks)(nil), // 44: management.Checks + (*PortInfo)(nil), // 45: management.PortInfo + (*RouteFirewallRule)(nil), // 46: management.RouteFirewallRule + (*ForwardingRule)(nil), // 47: management.ForwardingRule + nil, // 48: management.SSHAuth.MachineUsersEntry + (*PortInfo_Range)(nil), // 49: management.PortInfo.Range + (*timestamppb.Timestamp)(nil), // 50: google.protobuf.Timestamp + (*durationpb.Duration)(nil), // 51: google.protobuf.Duration } var file_management_proto_depIdxs = []int32{ 14, // 0: management.SyncRequest.meta:type_name -> management.PeerSystemMeta 18, // 1: management.SyncResponse.netbirdConfig:type_name -> management.NetbirdConfig 24, // 2: management.SyncResponse.peerConfig:type_name -> management.PeerConfig - 26, // 3: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig - 25, // 4: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap - 41, // 5: management.SyncResponse.Checks:type_name -> management.Checks + 29, // 3: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig + 26, // 4: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap + 44, // 5: management.SyncResponse.Checks:type_name -> management.Checks 14, // 6: management.SyncMetaRequest.meta:type_name -> management.PeerSystemMeta 14, // 7: management.LoginRequest.meta:type_name -> management.PeerSystemMeta 10, // 8: management.LoginRequest.peerKeys:type_name -> management.PeerKeys - 40, // 9: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress + 43, // 9: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress 11, // 10: management.PeerSystemMeta.environment:type_name -> management.Environment 12, // 11: management.PeerSystemMeta.files:type_name -> management.File 13, // 12: management.PeerSystemMeta.flags:type_name -> management.Flags 18, // 13: management.LoginResponse.netbirdConfig:type_name -> management.NetbirdConfig 24, // 14: management.LoginResponse.peerConfig:type_name -> management.PeerConfig - 41, // 15: management.LoginResponse.Checks:type_name -> management.Checks - 46, // 16: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp + 44, // 15: management.LoginResponse.Checks:type_name -> management.Checks + 50, // 16: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp 19, // 17: management.NetbirdConfig.stuns:type_name -> management.HostConfig 23, // 18: management.NetbirdConfig.turns:type_name -> management.ProtectedHostConfig 19, // 19: management.NetbirdConfig.signal:type_name -> management.HostConfig 20, // 20: management.NetbirdConfig.relay:type_name -> management.RelayConfig 21, // 21: management.NetbirdConfig.flow:type_name -> management.FlowConfig 3, // 22: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol - 47, // 23: management.FlowConfig.interval:type_name -> google.protobuf.Duration + 51, // 23: management.FlowConfig.interval:type_name -> google.protobuf.Duration 19, // 24: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig - 27, // 25: management.PeerConfig.sshConfig:type_name -> management.SSHConfig - 24, // 26: management.NetworkMap.peerConfig:type_name -> management.PeerConfig - 26, // 27: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig - 33, // 28: management.NetworkMap.Routes:type_name -> management.Route - 34, // 29: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig - 26, // 30: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig - 39, // 31: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule - 43, // 32: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule - 44, // 33: management.NetworkMap.forwardingRules:type_name -> management.ForwardingRule - 27, // 34: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig - 22, // 35: management.SSHConfig.jwtConfig:type_name -> management.JWTConfig - 4, // 36: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider - 32, // 37: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 32, // 38: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 37, // 39: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup - 35, // 40: management.DNSConfig.CustomZones:type_name -> management.CustomZone - 36, // 41: management.CustomZone.Records:type_name -> management.SimpleRecord - 38, // 42: management.NameServerGroup.NameServers:type_name -> management.NameServer - 1, // 43: management.FirewallRule.Direction:type_name -> management.RuleDirection - 2, // 44: management.FirewallRule.Action:type_name -> management.RuleAction - 0, // 45: management.FirewallRule.Protocol:type_name -> management.RuleProtocol - 42, // 46: management.FirewallRule.PortInfo:type_name -> management.PortInfo - 45, // 47: management.PortInfo.range:type_name -> management.PortInfo.Range - 2, // 48: management.RouteFirewallRule.action:type_name -> management.RuleAction - 0, // 49: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol - 42, // 50: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo - 0, // 51: management.ForwardingRule.protocol:type_name -> management.RuleProtocol - 42, // 52: management.ForwardingRule.destinationPort:type_name -> management.PortInfo - 42, // 53: management.ForwardingRule.translatedPort:type_name -> management.PortInfo - 5, // 54: management.ManagementService.Login:input_type -> management.EncryptedMessage - 5, // 55: management.ManagementService.Sync:input_type -> management.EncryptedMessage - 17, // 56: management.ManagementService.GetServerKey:input_type -> management.Empty - 17, // 57: management.ManagementService.isHealthy:input_type -> management.Empty - 5, // 58: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage - 5, // 59: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage - 5, // 60: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage - 5, // 61: management.ManagementService.Logout:input_type -> management.EncryptedMessage - 5, // 62: management.ManagementService.Login:output_type -> management.EncryptedMessage - 5, // 63: management.ManagementService.Sync:output_type -> management.EncryptedMessage - 16, // 64: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse - 17, // 65: management.ManagementService.isHealthy:output_type -> management.Empty - 5, // 66: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage - 5, // 67: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage - 17, // 68: management.ManagementService.SyncMeta:output_type -> management.Empty - 17, // 69: management.ManagementService.Logout:output_type -> management.Empty - 62, // [62:70] is the sub-list for method output_type - 54, // [54:62] is the sub-list for method input_type - 54, // [54:54] is the sub-list for extension type_name - 54, // [54:54] is the sub-list for extension extendee - 0, // [0:54] is the sub-list for field type_name + 30, // 25: management.PeerConfig.sshConfig:type_name -> management.SSHConfig + 25, // 26: management.PeerConfig.autoUpdate:type_name -> management.AutoUpdateSettings + 24, // 27: management.NetworkMap.peerConfig:type_name -> management.PeerConfig + 29, // 28: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig + 36, // 29: management.NetworkMap.Routes:type_name -> management.Route + 37, // 30: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig + 29, // 31: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig + 42, // 32: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule + 46, // 33: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule + 47, // 34: management.NetworkMap.forwardingRules:type_name -> management.ForwardingRule + 27, // 35: management.NetworkMap.sshAuth:type_name -> management.SSHAuth + 48, // 36: management.SSHAuth.machine_users:type_name -> management.SSHAuth.MachineUsersEntry + 30, // 37: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig + 22, // 38: management.SSHConfig.jwtConfig:type_name -> management.JWTConfig + 4, // 39: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider + 35, // 40: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 35, // 41: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 40, // 42: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup + 38, // 43: management.DNSConfig.CustomZones:type_name -> management.CustomZone + 39, // 44: management.CustomZone.Records:type_name -> management.SimpleRecord + 41, // 45: management.NameServerGroup.NameServers:type_name -> management.NameServer + 1, // 46: management.FirewallRule.Direction:type_name -> management.RuleDirection + 2, // 47: management.FirewallRule.Action:type_name -> management.RuleAction + 0, // 48: management.FirewallRule.Protocol:type_name -> management.RuleProtocol + 45, // 49: management.FirewallRule.PortInfo:type_name -> management.PortInfo + 49, // 50: management.PortInfo.range:type_name -> management.PortInfo.Range + 2, // 51: management.RouteFirewallRule.action:type_name -> management.RuleAction + 0, // 52: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol + 45, // 53: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo + 0, // 54: management.ForwardingRule.protocol:type_name -> management.RuleProtocol + 45, // 55: management.ForwardingRule.destinationPort:type_name -> management.PortInfo + 45, // 56: management.ForwardingRule.translatedPort:type_name -> management.PortInfo + 28, // 57: management.SSHAuth.MachineUsersEntry.value:type_name -> management.MachineUserIndexes + 5, // 58: management.ManagementService.Login:input_type -> management.EncryptedMessage + 5, // 59: management.ManagementService.Sync:input_type -> management.EncryptedMessage + 17, // 60: management.ManagementService.GetServerKey:input_type -> management.Empty + 17, // 61: management.ManagementService.isHealthy:input_type -> management.Empty + 5, // 62: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage + 5, // 63: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage + 5, // 64: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage + 5, // 65: management.ManagementService.Logout:input_type -> management.EncryptedMessage + 5, // 66: management.ManagementService.Login:output_type -> management.EncryptedMessage + 5, // 67: management.ManagementService.Sync:output_type -> management.EncryptedMessage + 16, // 68: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse + 17, // 69: management.ManagementService.isHealthy:output_type -> management.Empty + 5, // 70: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage + 5, // 71: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage + 17, // 72: management.ManagementService.SyncMeta:output_type -> management.Empty + 17, // 73: management.ManagementService.Logout:output_type -> management.Empty + 66, // [66:74] is the sub-list for method output_type + 58, // [58:66] is the sub-list for method input_type + 58, // [58:58] is the sub-list for extension type_name + 58, // [58:58] is the sub-list for extension extendee + 0, // [0:58] is the sub-list for field type_name } func init() { file_management_proto_init() } @@ -4457,7 +4684,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NetworkMap); i { + switch v := v.(*AutoUpdateSettings); i { case 0: return &v.state case 1: @@ -4469,7 +4696,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RemotePeerConfig); i { + switch v := v.(*NetworkMap); i { case 0: return &v.state case 1: @@ -4481,7 +4708,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SSHConfig); i { + switch v := v.(*SSHAuth); i { case 0: return &v.state case 1: @@ -4493,7 +4720,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[23].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DeviceAuthorizationFlowRequest); i { + switch v := v.(*MachineUserIndexes); i { case 0: return &v.state case 1: @@ -4505,7 +4732,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DeviceAuthorizationFlow); i { + switch v := v.(*RemotePeerConfig); i { case 0: return &v.state case 1: @@ -4517,7 +4744,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[25].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PKCEAuthorizationFlowRequest); i { + switch v := v.(*SSHConfig); i { case 0: return &v.state case 1: @@ -4529,7 +4756,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[26].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PKCEAuthorizationFlow); i { + switch v := v.(*DeviceAuthorizationFlowRequest); i { case 0: return &v.state case 1: @@ -4541,7 +4768,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[27].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ProviderConfig); i { + switch v := v.(*DeviceAuthorizationFlow); i { case 0: return &v.state case 1: @@ -4553,7 +4780,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[28].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Route); i { + switch v := v.(*PKCEAuthorizationFlowRequest); i { case 0: return &v.state case 1: @@ -4565,7 +4792,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[29].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DNSConfig); i { + switch v := v.(*PKCEAuthorizationFlow); i { case 0: return &v.state case 1: @@ -4577,7 +4804,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[30].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CustomZone); i { + switch v := v.(*ProviderConfig); i { case 0: return &v.state case 1: @@ -4589,7 +4816,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[31].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SimpleRecord); i { + switch v := v.(*Route); i { case 0: return &v.state case 1: @@ -4601,7 +4828,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[32].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NameServerGroup); i { + switch v := v.(*DNSConfig); i { case 0: return &v.state case 1: @@ -4613,7 +4840,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[33].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NameServer); i { + switch v := v.(*CustomZone); i { case 0: return &v.state case 1: @@ -4625,7 +4852,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[34].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*FirewallRule); i { + switch v := v.(*SimpleRecord); i { case 0: return &v.state case 1: @@ -4637,7 +4864,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[35].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NetworkAddress); i { + switch v := v.(*NameServerGroup); i { case 0: return &v.state case 1: @@ -4649,7 +4876,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[36].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Checks); i { + switch v := v.(*NameServer); i { case 0: return &v.state case 1: @@ -4661,7 +4888,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[37].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PortInfo); i { + switch v := v.(*FirewallRule); i { case 0: return &v.state case 1: @@ -4673,7 +4900,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[38].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RouteFirewallRule); i { + switch v := v.(*NetworkAddress); i { case 0: return &v.state case 1: @@ -4685,7 +4912,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[39].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ForwardingRule); i { + switch v := v.(*Checks); i { case 0: return &v.state case 1: @@ -4697,6 +4924,42 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[40].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*PortInfo); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[41].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RouteFirewallRule); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[42].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ForwardingRule); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[44].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*PortInfo_Range); i { case 0: return &v.state @@ -4709,7 +4972,7 @@ func file_management_proto_init() { } } } - file_management_proto_msgTypes[37].OneofWrappers = []interface{}{ + file_management_proto_msgTypes[40].OneofWrappers = []interface{}{ (*PortInfo_Port)(nil), (*PortInfo_Range_)(nil), } @@ -4719,7 +4982,7 @@ func file_management_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_management_proto_rawDesc, NumEnums: 5, - NumMessages: 41, + NumMessages: 45, NumExtensions: 0, NumServices: 1, }, diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index dc60b026d..f2e591e88 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -280,6 +280,18 @@ message PeerConfig { bool LazyConnectionEnabled = 6; int32 mtu = 7; + + // Auto-update config + AutoUpdateSettings autoUpdate = 8; +} + +message AutoUpdateSettings { + string version = 1; + /* + alwaysUpdate = true → Updates happen automatically in the background + alwaysUpdate = false → Updates only happen when triggered by a peer connection + */ + bool alwaysUpdate = 2; } // NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections @@ -320,6 +332,24 @@ message NetworkMap { bool routesFirewallRulesIsEmpty = 11; repeated ForwardingRule forwardingRules = 12; + + // SSHAuth represents SSH authorization configuration + SSHAuth sshAuth = 13; +} + +message SSHAuth { + // UserIDClaim is the JWT claim to be used to get the users ID + string UserIDClaim = 1; + + // AuthorizedUsers is a list of hashed user IDs authorized to access this peer via SSH + repeated bytes AuthorizedUsers = 2; + + // MachineUsers is a map of machine user names to their corresponding indexes in the AuthorizedUsers list + map machine_users = 3; +} + +message MachineUserIndexes { + repeated uint32 indexes = 1; } // RemotePeerConfig represents a configuration of a remote peer. diff --git a/shared/sshauth/userhash.go b/shared/sshauth/userhash.go new file mode 100644 index 000000000..276fc9ba2 --- /dev/null +++ b/shared/sshauth/userhash.go @@ -0,0 +1,28 @@ +package sshauth + +import ( + "encoding/hex" + + "golang.org/x/crypto/blake2b" +) + +// UserIDHash represents a hashed user ID (BLAKE2b-128) +type UserIDHash [16]byte + +// HashUserID hashes a user ID using BLAKE2b-128 and returns the hash value +// This function must produce the same hash on both client and management server +func HashUserID(userID string) (UserIDHash, error) { + hash, err := blake2b.New(16, nil) + if err != nil { + return UserIDHash{}, err + } + hash.Write([]byte(userID)) + var result UserIDHash + copy(result[:], hash.Sum(nil)) + return result, nil +} + +// String returns the hexadecimal string representation of the hash +func (h UserIDHash) String() string { + return hex.EncodeToString(h[:]) +} diff --git a/shared/sshauth/userhash_test.go b/shared/sshauth/userhash_test.go new file mode 100644 index 000000000..5a3cb6986 --- /dev/null +++ b/shared/sshauth/userhash_test.go @@ -0,0 +1,210 @@ +package sshauth + +import ( + "testing" +) + +func TestHashUserID(t *testing.T) { + tests := []struct { + name string + userID string + }{ + { + name: "simple user ID", + userID: "user@example.com", + }, + { + name: "UUID format", + userID: "550e8400-e29b-41d4-a716-446655440000", + }, + { + name: "numeric ID", + userID: "12345", + }, + { + name: "empty string", + userID: "", + }, + { + name: "special characters", + userID: "user+test@domain.com", + }, + { + name: "unicode characters", + userID: "用户@example.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hash, err := HashUserID(tt.userID) + if err != nil { + t.Errorf("HashUserID() error = %v, want nil", err) + return + } + + // Verify hash is non-zero for non-empty inputs + if tt.userID != "" && hash == [16]byte{} { + t.Errorf("HashUserID() returned zero hash for non-empty input") + } + }) + } +} + +func TestHashUserID_Consistency(t *testing.T) { + userID := "test@example.com" + + hash1, err1 := HashUserID(userID) + if err1 != nil { + t.Fatalf("First HashUserID() error = %v", err1) + } + + hash2, err2 := HashUserID(userID) + if err2 != nil { + t.Fatalf("Second HashUserID() error = %v", err2) + } + + if hash1 != hash2 { + t.Errorf("HashUserID() is not consistent: got %v and %v for same input", hash1, hash2) + } +} + +func TestHashUserID_Uniqueness(t *testing.T) { + tests := []struct { + userID1 string + userID2 string + }{ + {"user1@example.com", "user2@example.com"}, + {"alice@domain.com", "bob@domain.com"}, + {"test", "test1"}, + {"", "a"}, + } + + for _, tt := range tests { + hash1, err1 := HashUserID(tt.userID1) + if err1 != nil { + t.Fatalf("HashUserID(%s) error = %v", tt.userID1, err1) + } + + hash2, err2 := HashUserID(tt.userID2) + if err2 != nil { + t.Fatalf("HashUserID(%s) error = %v", tt.userID2, err2) + } + + if hash1 == hash2 { + t.Errorf("HashUserID() collision: %s and %s produced same hash %v", tt.userID1, tt.userID2, hash1) + } + } +} + +func TestUserIDHash_String(t *testing.T) { + tests := []struct { + name string + hash UserIDHash + expected string + }{ + { + name: "zero hash", + hash: [16]byte{}, + expected: "00000000000000000000000000000000", + }, + { + name: "small value", + hash: [16]byte{15: 0xff}, + expected: "000000000000000000000000000000ff", + }, + { + name: "large value", + hash: [16]byte{8: 0xde, 9: 0xad, 10: 0xbe, 11: 0xef, 12: 0xca, 13: 0xfe, 14: 0xba, 15: 0xbe}, + expected: "0000000000000000deadbeefcafebabe", + }, + { + name: "max value", + hash: [16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + expected: "ffffffffffffffffffffffffffffffff", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.hash.String() + if result != tt.expected { + t.Errorf("UserIDHash.String() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestUserIDHash_String_Length(t *testing.T) { + // Test that String() always returns 32 hex characters (16 bytes * 2) + userID := "test@example.com" + hash, err := HashUserID(userID) + if err != nil { + t.Fatalf("HashUserID() error = %v", err) + } + + result := hash.String() + if len(result) != 32 { + t.Errorf("UserIDHash.String() length = %d, want 32", len(result)) + } + + // Verify it's valid hex + for i, c := range result { + if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) { + t.Errorf("UserIDHash.String() contains non-hex character at position %d: %c", i, c) + } + } +} + +func TestHashUserID_KnownValues(t *testing.T) { + // Test with known BLAKE2b-128 values to ensure correct implementation + tests := []struct { + name string + userID string + expected UserIDHash + }{ + { + name: "empty string", + userID: "", + // BLAKE2b-128 of empty string + expected: [16]byte{0xca, 0xe6, 0x69, 0x41, 0xd9, 0xef, 0xbd, 0x40, 0x4e, 0x4d, 0x88, 0x75, 0x8e, 0xa6, 0x76, 0x70}, + }, + { + name: "single character 'a'", + userID: "a", + // BLAKE2b-128 of "a" + expected: [16]byte{0x27, 0xc3, 0x5e, 0x6e, 0x93, 0x73, 0x87, 0x7f, 0x29, 0xe5, 0x62, 0x46, 0x4e, 0x46, 0x49, 0x7e}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hash, err := HashUserID(tt.userID) + if err != nil { + t.Errorf("HashUserID() error = %v", err) + return + } + + if hash != tt.expected { + t.Errorf("HashUserID(%q) = %x, want %x", + tt.userID, hash, tt.expected) + } + }) + } +} + +func BenchmarkHashUserID(b *testing.B) { + userID := "user@example.com" + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = HashUserID(userID) + } +} + +func BenchmarkUserIDHash_String(b *testing.B) { + hash := UserIDHash([16]byte{8: 0xde, 9: 0xad, 10: 0xbe, 11: 0xef, 12: 0xca, 13: 0xfe, 14: 0xba, 15: 0xbe}) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = hash.String() + } +} diff --git a/util/crypt/crypt.go b/util/crypt/crypt.go new file mode 100644 index 000000000..0e5589895 --- /dev/null +++ b/util/crypt/crypt.go @@ -0,0 +1,96 @@ +package crypt + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "fmt" + "io" +) + +// FieldEncrypt provides AES-GCM encryption for sensitive fields. +type FieldEncrypt struct { + block cipher.Block +} + +// NewFieldEncrypt creates a new FieldEncrypt with the given base64-encoded key. +// The key must be 32 bytes when decoded (for AES-256). +func NewFieldEncrypt(base64Key string) (*FieldEncrypt, error) { + key, err := base64.StdEncoding.DecodeString(base64Key) + if err != nil { + return nil, fmt.Errorf("decode encryption key: %w", err) + } + + if len(key) != 32 { + return nil, fmt.Errorf("encryption key must be 32 bytes, got %d", len(key)) + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("create cipher: %w", err) + } + + return &FieldEncrypt{block: block}, nil +} + +// Encrypt encrypts the given plaintext and returns base64-encoded ciphertext. +// Returns empty string for empty input. +func (f *FieldEncrypt) Encrypt(plaintext string) (string, error) { + if plaintext == "" { + return "", nil + } + + gcm, err := cipher.NewGCM(f.block) + if err != nil { + return "", fmt.Errorf("create GCM: %w", err) + } + + nonce := make([]byte, gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return "", fmt.Errorf("generate nonce: %w", err) + } + + ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil) + return base64.StdEncoding.EncodeToString(ciphertext), nil +} + +// Decrypt decrypts the given base64-encoded ciphertext and returns the plaintext. +// Returns empty string for empty input. +func (f *FieldEncrypt) Decrypt(ciphertext string) (string, error) { + if ciphertext == "" { + return "", nil + } + + data, err := base64.StdEncoding.DecodeString(ciphertext) + if err != nil { + return "", fmt.Errorf("decode ciphertext: %w", err) + } + + gcm, err := cipher.NewGCM(f.block) + if err != nil { + return "", fmt.Errorf("create GCM: %w", err) + } + + nonceSize := gcm.NonceSize() + if len(data) < nonceSize { + return "", fmt.Errorf("ciphertext too short") + } + + nonce, ciphertextBytes := data[:nonceSize], data[nonceSize:] + plaintext, err := gcm.Open(nil, nonce, ciphertextBytes, nil) + if err != nil { + return "", fmt.Errorf("decrypt: %w", err) + } + + return string(plaintext), nil +} + +// GenerateKey generates a new random 32-byte encryption key and returns it as base64. +func GenerateKey() (string, error) { + key := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + return "", fmt.Errorf("generate key: %w", err) + } + return base64.StdEncoding.EncodeToString(key), nil +} diff --git a/util/semaphore-group/semaphore_group.go b/util/semaphore-group/semaphore_group.go index ad74e1bfc..462300672 100644 --- a/util/semaphore-group/semaphore_group.go +++ b/util/semaphore-group/semaphore_group.go @@ -2,12 +2,10 @@ package semaphoregroup import ( "context" - "sync" ) // SemaphoreGroup is a custom type that combines sync.WaitGroup and a semaphore. type SemaphoreGroup struct { - waitGroup sync.WaitGroup semaphore chan struct{} } @@ -18,31 +16,18 @@ func NewSemaphoreGroup(limit int) *SemaphoreGroup { } } -// Add increments the internal WaitGroup counter and acquires a semaphore slot. -func (sg *SemaphoreGroup) Add(ctx context.Context) { - sg.waitGroup.Add(1) - +// Add acquire a slot +func (sg *SemaphoreGroup) Add(ctx context.Context) error { // Acquire semaphore slot select { case <-ctx.Done(): - return + return ctx.Err() case sg.semaphore <- struct{}{}: + return nil } } -// Done decrements the internal WaitGroup counter and releases a semaphore slot. -func (sg *SemaphoreGroup) Done(ctx context.Context) { - sg.waitGroup.Done() - - // Release semaphore slot - select { - case <-ctx.Done(): - return - case <-sg.semaphore: - } -} - -// Wait waits until the internal WaitGroup counter is zero. -func (sg *SemaphoreGroup) Wait() { - sg.waitGroup.Wait() +// Done releases a slot. Must be called after a successful Add. +func (sg *SemaphoreGroup) Done() { + <-sg.semaphore } diff --git a/util/semaphore-group/semaphore_group_test.go b/util/semaphore-group/semaphore_group_test.go index d4491cf77..9406da4a0 100644 --- a/util/semaphore-group/semaphore_group_test.go +++ b/util/semaphore-group/semaphore_group_test.go @@ -2,65 +2,89 @@ package semaphoregroup import ( "context" + "sync" "testing" "time" ) func TestSemaphoreGroup(t *testing.T) { - semGroup := NewSemaphoreGroup(2) - - for i := 0; i < 5; i++ { - semGroup.Add(context.Background()) - go func(id int) { - defer semGroup.Done(context.Background()) - - got := len(semGroup.semaphore) - if got == 0 { - t.Errorf("Expected semaphore length > 0 , got 0") - } - - time.Sleep(time.Millisecond) - t.Logf("Goroutine %d is running\n", id) - }(i) - } - - semGroup.Wait() - - want := 0 - got := len(semGroup.semaphore) - if got != want { - t.Errorf("Expected semaphore length %d, got %d", want, got) - } -} - -func TestSemaphoreGroupContext(t *testing.T) { semGroup := NewSemaphoreGroup(1) - semGroup.Add(context.Background()) - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + _ = semGroup.Add(context.Background()) + + ctxTimeout, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) t.Cleanup(cancel) - rChan := make(chan struct{}) - go func() { - semGroup.Add(ctx) - rChan <- struct{}{} - }() - select { - case <-rChan: - case <-time.NewTimer(2 * time.Second).C: - t.Error("Adding to semaphore group should not block when context is not done") - } - - semGroup.Done(context.Background()) - - ctxDone, cancelDone := context.WithTimeout(context.Background(), 1*time.Second) - t.Cleanup(cancelDone) - go func() { - semGroup.Done(ctxDone) - rChan <- struct{}{} - }() - select { - case <-rChan: - case <-time.NewTimer(2 * time.Second).C: - t.Error("Releasing from semaphore group should not block when context is not done") + if err := semGroup.Add(ctxTimeout); err == nil { + t.Error("Adding to semaphore group should not block") + } +} + +func TestSemaphoreGroupFreeUp(t *testing.T) { + semGroup := NewSemaphoreGroup(1) + _ = semGroup.Add(context.Background()) + semGroup.Done() + + ctxTimeout, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + t.Cleanup(cancel) + if err := semGroup.Add(ctxTimeout); err != nil { + t.Error(err) + } +} + +func TestSemaphoreGroupCanceledContext(t *testing.T) { + semGroup := NewSemaphoreGroup(1) + _ = semGroup.Add(context.Background()) + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + if err := semGroup.Add(ctx); err == nil { + t.Error("Add should return error when context is already canceled") + } +} + +func TestSemaphoreGroupCancelWhileWaiting(t *testing.T) { + semGroup := NewSemaphoreGroup(1) + _ = semGroup.Add(context.Background()) + + ctx, cancel := context.WithCancel(context.Background()) + errChan := make(chan error, 1) + + go func() { + errChan <- semGroup.Add(ctx) + }() + + time.Sleep(10 * time.Millisecond) + cancel() + + if err := <-errChan; err == nil { + t.Error("Add should return error when context is canceled while waiting") + } +} + +func TestSemaphoreGroupHighConcurrency(t *testing.T) { + const limit = 10 + const numGoroutines = 100 + + semGroup := NewSemaphoreGroup(limit) + var wg sync.WaitGroup + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if err := semGroup.Add(context.Background()); err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + time.Sleep(time.Millisecond) + semGroup.Done() + }() + } + + wg.Wait() + + // Verify all slots were released + if got := len(semGroup.semaphore); got != 0 { + t.Errorf("Expected semaphore to be empty, got %d slots occupied", got) } } diff --git a/version/update.go b/version/update.go index 272eef4c6..a324d97fe 100644 --- a/version/update.go +++ b/version/update.go @@ -41,21 +41,28 @@ func NewUpdate(httpAgent string) *Update { currentVersion, _ = goversion.NewVersion("0.0.0") } - latestAvailable, _ := goversion.NewVersion("0.0.0") - u := &Update{ - httpAgent: httpAgent, - latestAvailable: latestAvailable, - uiVersion: currentVersion, - fetchTicker: time.NewTicker(fetchPeriod), - fetchDone: make(chan struct{}), + httpAgent: httpAgent, + uiVersion: currentVersion, + fetchDone: make(chan struct{}), } - go u.startFetcher() + + return u +} + +func NewUpdateAndStart(httpAgent string) *Update { + u := NewUpdate(httpAgent) + go u.StartFetcher() + return u } // StopWatch stop the version info fetch loop func (u *Update) StopWatch() { + if u.fetchTicker == nil { + return + } + u.fetchTicker.Stop() select { @@ -94,7 +101,18 @@ func (u *Update) SetOnUpdateListener(updateFn func()) { } } -func (u *Update) startFetcher() { +func (u *Update) LatestVersion() *goversion.Version { + u.versionsLock.Lock() + defer u.versionsLock.Unlock() + return u.latestAvailable +} + +func (u *Update) StartFetcher() { + if u.fetchTicker != nil { + return + } + u.fetchTicker = time.NewTicker(fetchPeriod) + if changed := u.fetchVersion(); changed { u.checkUpdate() } @@ -181,6 +199,10 @@ func (u *Update) isUpdateAvailable() bool { u.versionsLock.Lock() defer u.versionsLock.Unlock() + if u.latestAvailable == nil { + return false + } + if u.latestAvailable.GreaterThan(u.uiVersion) { return true } diff --git a/version/update_test.go b/version/update_test.go index a733714cf..d5d60800e 100644 --- a/version/update_test.go +++ b/version/update_test.go @@ -23,7 +23,7 @@ func TestNewUpdate(t *testing.T) { wg.Add(1) onUpdate := false - u := NewUpdate(httpAgent) + u := NewUpdateAndStart(httpAgent) defer u.StopWatch() u.SetOnUpdateListener(func() { onUpdate = true @@ -48,7 +48,7 @@ func TestDoNotUpdate(t *testing.T) { wg.Add(1) onUpdate := false - u := NewUpdate(httpAgent) + u := NewUpdateAndStart(httpAgent) defer u.StopWatch() u.SetOnUpdateListener(func() { onUpdate = true @@ -73,7 +73,7 @@ func TestDaemonUpdate(t *testing.T) { wg.Add(1) onUpdate := false - u := NewUpdate(httpAgent) + u := NewUpdateAndStart(httpAgent) defer u.StopWatch() u.SetOnUpdateListener(func() { onUpdate = true