diff --git a/.github/workflows/golang-test-freebsd.yml b/.github/workflows/golang-test-freebsd.yml
index b03313bbd..0d19e8a19 100644
--- a/.github/workflows/golang-test-freebsd.yml
+++ b/.github/workflows/golang-test-freebsd.yml
@@ -39,7 +39,7 @@ jobs:
# check all component except management, since we do not support management server on freebsd
time go test -timeout 1m -failfast ./base62/...
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
- time go test -timeout 8m -failfast -p 1 ./client/...
+ time go test -timeout 8m -failfast -v -p 1 ./client/...
time go test -timeout 1m -failfast ./dns/...
time go test -timeout 1m -failfast ./encryption/...
time go test -timeout 1m -failfast ./formatter/...
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index a9bc1b979..2fa847dce 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -9,7 +9,7 @@ on:
pull_request:
env:
- SIGN_PIPE_VER: "v0.0.23"
+ SIGN_PIPE_VER: "v0.1.0"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"
@@ -19,6 +19,100 @@ concurrency:
cancel-in-progress: true
jobs:
+ release_freebsd_port:
+ name: "FreeBSD Port / Build & Test"
+ runs-on: ubuntu-22.04
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v4
+
+ - name: Generate FreeBSD port diff
+ run: bash release_files/freebsd-port-diff.sh
+
+ - name: Generate FreeBSD port issue body
+ run: bash release_files/freebsd-port-issue-body.sh
+
+ - name: Check if diff was generated
+ id: check_diff
+ run: |
+ if ls netbird-*.diff 1> /dev/null 2>&1; then
+ echo "diff_exists=true" >> $GITHUB_OUTPUT
+ else
+ echo "diff_exists=false" >> $GITHUB_OUTPUT
+ echo "No diff file generated (port may already be up to date)"
+ fi
+
+ - name: Extract version
+ if: steps.check_diff.outputs.diff_exists == 'true'
+ id: version
+ run: |
+ VERSION=$(ls netbird-*.diff | sed 's/netbird-\(.*\)\.diff/\1/')
+ echo "version=$VERSION" >> $GITHUB_OUTPUT
+ echo "Generated files for version: $VERSION"
+ cat netbird-*.diff
+
+ - name: Test FreeBSD port
+ if: steps.check_diff.outputs.diff_exists == 'true'
+ uses: vmactions/freebsd-vm@v1
+ with:
+ usesh: true
+ copyback: false
+ release: "15.0"
+ prepare: |
+ # Install required packages
+ pkg install -y git curl portlint go
+
+ # Install Go for building
+ GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
+ GO_URL="https://go.dev/dl/$GO_TARBALL"
+ curl -LO "$GO_URL"
+ tar -C /usr/local -xzf "$GO_TARBALL"
+
+ # Clone ports tree (shallow, only what we need)
+ git clone --depth 1 --filter=blob:none https://git.FreeBSD.org/ports.git /usr/ports
+ cd /usr/ports
+
+ run: |
+ set -e -x
+ export PATH=$PATH:/usr/local/go/bin
+
+ # Find the diff file
+ echo "Finding diff file..."
+ DIFF_FILE=$(find $PWD -name "netbird-*.diff" -type f 2>/dev/null | head -1)
+ echo "Found: $DIFF_FILE"
+
+ if [ -z "$DIFF_FILE" ]; then
+ echo "ERROR: Could not find diff file"
+ find ~ -name "*.diff" -type f 2>/dev/null || true
+ exit 1
+ fi
+
+ # Apply the generated diff from /usr/ports (diff has a/security/netbird/... paths)
+ cd /usr/ports
+ patch -p1 -V none < "$DIFF_FILE"
+
+ # Extract DISTVERSION from the patched Makefile
+ version=$(grep -E '^DISTVERSION=' security/netbird/Makefile | awk '{print $NF}')
+
+ cd /usr/ports/security/netbird
+ export BATCH=yes
+ make package
+ pkg add ./work/pkg/netbird-*.pkg
+
+ netbird version | grep "$version"
+
+ echo "FreeBSD port test completed successfully!"
+
+ - name: Upload FreeBSD port files
+ if: steps.check_diff.outputs.diff_exists == 'true'
+ uses: actions/upload-artifact@v4
+ with:
+ name: freebsd-port-files
+ path: |
+ ./netbird-*-issue.txt
+ ./netbird-*.diff
+ retention-days: 30
+
release:
runs-on: ubuntu-latest-m
env:
diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml
index f4513e0e1..e2f950731 100644
--- a/.github/workflows/test-infrastructure-files.yml
+++ b/.github/workflows/test-infrastructure-files.yml
@@ -243,6 +243,7 @@ jobs:
working-directory: infrastructure_files/artifacts
run: |
sleep 30
+ docker compose logs
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City_[0-9]*.mmdb
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames_[0-9]*.db
diff --git a/.gitignore b/.gitignore
index e6c0c0aca..89024d190 100644
--- a/.gitignore
+++ b/.gitignore
@@ -31,3 +31,4 @@ infrastructure_files/setup-*.env
.DS_Store
vendor/
/netbird
+client/netbird-electron/
diff --git a/.goreleaser.yaml b/.goreleaser.yaml
index 952e946dc..7c6651f83 100644
--- a/.goreleaser.yaml
+++ b/.goreleaser.yaml
@@ -713,8 +713,10 @@ checksum:
extra_files:
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
- glob: ./release_files/install.sh
+ - glob: ./infrastructure_files/getting-started.sh
release:
extra_files:
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
- glob: ./release_files/install.sh
+ - glob: ./infrastructure_files/getting-started.sh
diff --git a/README.md b/README.md
index 2c5ee2ab6..28b53d5b6 100644
--- a/README.md
+++ b/README.md
@@ -85,7 +85,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird
**Infrastructure requirements:**
- A Linux VM with at least **1CPU** and **2GB** of memory.
-- The VM should be publicly accessible on TCP ports **80** and **443** and UDP ports: **3478**, **49152-65535**.
+- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port: **3478**.
- **Public domain** name pointing to the VM.
**Software requirements:**
@@ -98,7 +98,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird
**Steps**
- Download and run the installation script:
```bash
-export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started-with-zitadel.sh | bash
+export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started.sh | bash
```
- Once finished, you can manage the resources via `docker-compose`
@@ -113,7 +113,7 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
-
+
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
diff --git a/client/android/client.go b/client/android/client.go
index 0d5474c4b..ccf32a90c 100644
--- a/client/android/client.go
+++ b/client/android/client.go
@@ -59,7 +59,6 @@ func init() {
// Client struct manage the life circle of background service
type Client struct {
- cfgFile string
tunAdapter device.TunAdapter
iFaceDiscover IFaceDiscover
recorder *peer.Status
@@ -68,18 +67,16 @@ type Client struct {
deviceName string
uiVersion string
networkChangeListener listener.NetworkChangeListener
- stateFile string
connectClient *internal.ConnectClient
}
// NewClient instantiate a new Client
-func NewClient(platformFiles PlatformFiles, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
+func NewClient(androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
execWorkaround(androidSDKVersion)
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
return &Client{
- cfgFile: platformFiles.ConfigurationFilePath(),
deviceName: deviceName,
uiVersion: uiVersion,
tunAdapter: tunAdapter,
@@ -87,15 +84,20 @@ func NewClient(platformFiles PlatformFiles, androidSDKVersion int, deviceName st
recorder: peer.NewRecorder(""),
ctxCancelLock: &sync.Mutex{},
networkChangeListener: networkChangeListener,
- stateFile: platformFiles.StateFilePath(),
}
}
// Run start the internal client. It is a blocker function
-func (c *Client) Run(urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
+func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
+
+ cfgFile := platformFiles.ConfigurationFilePath()
+ stateFile := platformFiles.StateFilePath()
+
+ log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
+
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
- ConfigPath: c.cfgFile,
+ ConfigPath: cfgFile,
})
if err != nil {
return err
@@ -122,16 +124,22 @@ func (c *Client) Run(urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsRea
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
- c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
- return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile)
+ c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
+ return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
}
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
// In this case make no sense handle registration steps.
-func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
+func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
+
+ cfgFile := platformFiles.ConfigurationFilePath()
+ stateFile := platformFiles.StateFilePath()
+
+ log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
+
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
- ConfigPath: c.cfgFile,
+ ConfigPath: cfgFile,
})
if err != nil {
return err
@@ -149,8 +157,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
- c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
- return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile)
+ c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
+ return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
}
// Stop the internal client and free the resources
diff --git a/client/android/profile_manager.go b/client/android/profile_manager.go
new file mode 100644
index 000000000..60e4d5c32
--- /dev/null
+++ b/client/android/profile_manager.go
@@ -0,0 +1,257 @@
+//go:build android
+
+package android
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/client/internal/profilemanager"
+)
+
+const (
+ // Android-specific config filename (different from desktop default.json)
+ defaultConfigFilename = "netbird.cfg"
+ // Subdirectory for non-default profiles (must match Java Preferences.java)
+ profilesSubdir = "profiles"
+ // Android uses a single user context per app (non-empty username required by ServiceManager)
+ androidUsername = "android"
+)
+
+// Profile represents a profile for gomobile
+type Profile struct {
+ Name string
+ IsActive bool
+}
+
+// ProfileArray wraps profiles for gomobile compatibility
+type ProfileArray struct {
+ items []*Profile
+}
+
+// Length returns the number of profiles
+func (p *ProfileArray) Length() int {
+ return len(p.items)
+}
+
+// Get returns the profile at index i
+func (p *ProfileArray) Get(i int) *Profile {
+ if i < 0 || i >= len(p.items) {
+ return nil
+ }
+ return p.items[i]
+}
+
+/*
+
+/data/data/io.netbird.client/files/ ← configDir parameter
+├── netbird.cfg ← Default profile config
+├── state.json ← Default profile state
+├── active_profile.json ← Active profile tracker (JSON with Name + Username)
+└── profiles/ ← Subdirectory for non-default profiles
+ ├── work.json ← Work profile config
+ ├── work.state.json ← Work profile state
+ ├── personal.json ← Personal profile config
+ └── personal.state.json ← Personal profile state
+*/
+
+// ProfileManager manages profiles for Android
+// It wraps the internal profilemanager to provide Android-specific behavior
+type ProfileManager struct {
+ configDir string
+ serviceMgr *profilemanager.ServiceManager
+}
+
+// NewProfileManager creates a new profile manager for Android
+func NewProfileManager(configDir string) *ProfileManager {
+ // Set the default config path for Android (stored in root configDir, not profiles/)
+ defaultConfigPath := filepath.Join(configDir, defaultConfigFilename)
+
+ // Set global paths for Android
+ profilemanager.DefaultConfigPathDir = configDir
+ profilemanager.DefaultConfigPath = defaultConfigPath
+ profilemanager.ActiveProfileStatePath = filepath.Join(configDir, "active_profile.json")
+
+ // Create ServiceManager with profiles/ subdirectory
+ // This avoids modifying the global ConfigDirOverride for profile listing
+ profilesDir := filepath.Join(configDir, profilesSubdir)
+ serviceMgr := profilemanager.NewServiceManagerWithProfilesDir(defaultConfigPath, profilesDir)
+
+ return &ProfileManager{
+ configDir: configDir,
+ serviceMgr: serviceMgr,
+ }
+}
+
+// ListProfiles returns all available profiles
+func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
+ // Use ServiceManager (looks in profiles/ directory, checks active_profile.json for IsActive)
+ internalProfiles, err := pm.serviceMgr.ListProfiles(androidUsername)
+ if err != nil {
+ return nil, fmt.Errorf("failed to list profiles: %w", err)
+ }
+
+ // Convert internal profiles to Android Profile type
+ var profiles []*Profile
+ for _, p := range internalProfiles {
+ profiles = append(profiles, &Profile{
+ Name: p.Name,
+ IsActive: p.IsActive,
+ })
+ }
+
+ return &ProfileArray{items: profiles}, nil
+}
+
+// GetActiveProfile returns the currently active profile name
+func (pm *ProfileManager) GetActiveProfile() (string, error) {
+ // Use ServiceManager to stay consistent with ListProfiles
+ // ServiceManager uses active_profile.json
+ activeState, err := pm.serviceMgr.GetActiveProfileState()
+ if err != nil {
+ return "", fmt.Errorf("failed to get active profile: %w", err)
+ }
+ return activeState.Name, nil
+}
+
+// SwitchProfile switches to a different profile
+func (pm *ProfileManager) SwitchProfile(profileName string) error {
+ // Use ServiceManager to stay consistent with ListProfiles
+ // ServiceManager uses active_profile.json
+ err := pm.serviceMgr.SetActiveProfileState(&profilemanager.ActiveProfileState{
+ Name: profileName,
+ Username: androidUsername,
+ })
+ if err != nil {
+ return fmt.Errorf("failed to switch profile: %w", err)
+ }
+
+ log.Infof("switched to profile: %s", profileName)
+ return nil
+}
+
+// AddProfile creates a new profile
+func (pm *ProfileManager) AddProfile(profileName string) error {
+ // Use ServiceManager (creates profile in profiles/ directory)
+ if err := pm.serviceMgr.AddProfile(profileName, androidUsername); err != nil {
+ return fmt.Errorf("failed to add profile: %w", err)
+ }
+
+ log.Infof("created new profile: %s", profileName)
+ return nil
+}
+
+// LogoutProfile logs out from a profile (clears authentication)
+func (pm *ProfileManager) LogoutProfile(profileName string) error {
+ profileName = sanitizeProfileName(profileName)
+
+ configPath, err := pm.getProfileConfigPath(profileName)
+ if err != nil {
+ return err
+ }
+
+ // Check if profile exists
+ if _, err := os.Stat(configPath); os.IsNotExist(err) {
+ return fmt.Errorf("profile '%s' does not exist", profileName)
+ }
+
+ // Read current config using internal profilemanager
+ config, err := profilemanager.ReadConfig(configPath)
+ if err != nil {
+ return fmt.Errorf("failed to read profile config: %w", err)
+ }
+
+ // Clear authentication by removing private key and SSH key
+ config.PrivateKey = ""
+ config.SSHKey = ""
+
+ // Save config using internal profilemanager
+ if err := profilemanager.WriteOutConfig(configPath, config); err != nil {
+ return fmt.Errorf("failed to save config: %w", err)
+ }
+
+ log.Infof("logged out from profile: %s", profileName)
+ return nil
+}
+
+// RemoveProfile deletes a profile
+func (pm *ProfileManager) RemoveProfile(profileName string) error {
+ // Use ServiceManager (removes profile from profiles/ directory)
+ if err := pm.serviceMgr.RemoveProfile(profileName, androidUsername); err != nil {
+ return fmt.Errorf("failed to remove profile: %w", err)
+ }
+
+ log.Infof("removed profile: %s", profileName)
+ return nil
+}
+
+// getProfileConfigPath returns the config file path for a profile
+// This is needed for Android-specific path handling (netbird.cfg for default profile)
+func (pm *ProfileManager) getProfileConfigPath(profileName string) (string, error) {
+ if profileName == "" || profileName == profilemanager.DefaultProfileName {
+ // Android uses netbird.cfg for default profile instead of default.json
+ // Default profile is stored in root configDir, not in profiles/
+ return filepath.Join(pm.configDir, defaultConfigFilename), nil
+ }
+
+ // Non-default profiles are stored in profiles subdirectory
+ // This matches the Java Preferences.java expectation
+ profileName = sanitizeProfileName(profileName)
+ profilesDir := filepath.Join(pm.configDir, profilesSubdir)
+ return filepath.Join(profilesDir, profileName+".json"), nil
+}
+
+// GetConfigPath returns the config file path for a given profile
+// Java should call this instead of constructing paths with Preferences.configFile()
+func (pm *ProfileManager) GetConfigPath(profileName string) (string, error) {
+ return pm.getProfileConfigPath(profileName)
+}
+
+// GetStateFilePath returns the state file path for a given profile
+// Java should call this instead of constructing paths with Preferences.stateFile()
+func (pm *ProfileManager) GetStateFilePath(profileName string) (string, error) {
+ if profileName == "" || profileName == profilemanager.DefaultProfileName {
+ return filepath.Join(pm.configDir, "state.json"), nil
+ }
+
+ profileName = sanitizeProfileName(profileName)
+ profilesDir := filepath.Join(pm.configDir, profilesSubdir)
+ return filepath.Join(profilesDir, profileName+".state.json"), nil
+}
+
+// GetActiveConfigPath returns the config file path for the currently active profile
+// Java should call this instead of Preferences.getActiveProfileName() + Preferences.configFile()
+func (pm *ProfileManager) GetActiveConfigPath() (string, error) {
+ activeProfile, err := pm.GetActiveProfile()
+ if err != nil {
+ return "", fmt.Errorf("failed to get active profile: %w", err)
+ }
+ return pm.GetConfigPath(activeProfile)
+}
+
+// GetActiveStateFilePath returns the state file path for the currently active profile
+// Java should call this instead of Preferences.getActiveProfileName() + Preferences.stateFile()
+func (pm *ProfileManager) GetActiveStateFilePath() (string, error) {
+ activeProfile, err := pm.GetActiveProfile()
+ if err != nil {
+ return "", fmt.Errorf("failed to get active profile: %w", err)
+ }
+ return pm.GetStateFilePath(activeProfile)
+}
+
+// sanitizeProfileName removes invalid characters from profile name
+func sanitizeProfileName(name string) string {
+ // Keep only alphanumeric, underscore, and hyphen
+ var result strings.Builder
+ for _, r := range name {
+ if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') ||
+ (r >= '0' && r <= '9') || r == '_' || r == '-' {
+ result.WriteRune(r)
+ }
+ }
+ return result.String()
+}
diff --git a/client/cmd/root.go b/client/cmd/root.go
index 9f2eb109c..30120c196 100644
--- a/client/cmd/root.go
+++ b/client/cmd/root.go
@@ -85,6 +85,9 @@ var (
// Execute executes the root command.
func Execute() error {
+ if isUpdateBinary() {
+ return updateCmd.Execute()
+ }
return rootCmd.Execute()
}
diff --git a/client/cmd/signer/artifactkey.go b/client/cmd/signer/artifactkey.go
new file mode 100644
index 000000000..5e656650b
--- /dev/null
+++ b/client/cmd/signer/artifactkey.go
@@ -0,0 +1,176 @@
+package main
+
+import (
+ "fmt"
+ "os"
+ "time"
+
+ "github.com/spf13/cobra"
+
+ "github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
+)
+
+var (
+ bundlePubKeysRootPrivKeyFile string
+ bundlePubKeysPubKeyFiles []string
+ bundlePubKeysFile string
+
+ createArtifactKeyRootPrivKeyFile string
+ createArtifactKeyPrivKeyFile string
+ createArtifactKeyPubKeyFile string
+ createArtifactKeyExpiration time.Duration
+)
+
+var createArtifactKeyCmd = &cobra.Command{
+ Use: "create-artifact-key",
+ Short: "Create a new artifact signing key",
+ Long: `Generate a new artifact signing key pair signed by the root private key.
+The artifact key will be used to sign software artifacts/updates.`,
+ SilenceUsage: true,
+ RunE: func(cmd *cobra.Command, args []string) error {
+ if createArtifactKeyExpiration <= 0 {
+ return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 8760h)")
+ }
+
+ if err := handleCreateArtifactKey(cmd, createArtifactKeyRootPrivKeyFile, createArtifactKeyPrivKeyFile, createArtifactKeyPubKeyFile, createArtifactKeyExpiration); err != nil {
+ return fmt.Errorf("failed to create artifact key: %w", err)
+ }
+ return nil
+ },
+}
+
+var bundlePubKeysCmd = &cobra.Command{
+ Use: "bundle-pub-keys",
+ Short: "Bundle multiple artifact public keys into a signed package",
+ Long: `Bundle one or more artifact public keys into a signed package using the root private key.
+This command is typically used to distribute or authorize a set of valid artifact signing keys.`,
+ RunE: func(cmd *cobra.Command, args []string) error {
+ if len(bundlePubKeysPubKeyFiles) == 0 {
+ return fmt.Errorf("at least one --artifact-pub-key-file must be provided")
+ }
+
+ if err := handleBundlePubKeys(cmd, bundlePubKeysRootPrivKeyFile, bundlePubKeysPubKeyFiles, bundlePubKeysFile); err != nil {
+ return fmt.Errorf("failed to bundle public keys: %w", err)
+ }
+ return nil
+ },
+}
+
+func init() {
+ rootCmd.AddCommand(createArtifactKeyCmd)
+
+ createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the artifact key")
+ createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPrivKeyFile, "artifact-priv-key-file", "", "Path where the artifact private key will be saved")
+ createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPubKeyFile, "artifact-pub-key-file", "", "Path where the artifact public key will be saved")
+ createArtifactKeyCmd.Flags().DurationVar(&createArtifactKeyExpiration, "expiration", 0, "Expiration duration for the artifact key (e.g., 720h, 8760h)")
+
+ if err := createArtifactKeyCmd.MarkFlagRequired("root-private-key-file"); err != nil {
+ panic(fmt.Errorf("mark root-private-key-file as required: %w", err))
+ }
+ if err := createArtifactKeyCmd.MarkFlagRequired("artifact-priv-key-file"); err != nil {
+ panic(fmt.Errorf("mark artifact-priv-key-file as required: %w", err))
+ }
+ if err := createArtifactKeyCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil {
+ panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err))
+ }
+ if err := createArtifactKeyCmd.MarkFlagRequired("expiration"); err != nil {
+ panic(fmt.Errorf("mark expiration as required: %w", err))
+ }
+
+ rootCmd.AddCommand(bundlePubKeysCmd)
+
+ bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the bundle")
+ bundlePubKeysCmd.Flags().StringArrayVar(&bundlePubKeysPubKeyFiles, "artifact-pub-key-file", nil, "Path(s) to the artifact public key files to include in the bundle (can be repeated)")
+ bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysFile, "bundle-pub-key-file", "", "Path where the public keys will be saved")
+
+ if err := bundlePubKeysCmd.MarkFlagRequired("root-private-key-file"); err != nil {
+ panic(fmt.Errorf("mark root-private-key-file as required: %w", err))
+ }
+ if err := bundlePubKeysCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil {
+ panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err))
+ }
+ if err := bundlePubKeysCmd.MarkFlagRequired("bundle-pub-key-file"); err != nil {
+ panic(fmt.Errorf("mark bundle-pub-key-file as required: %w", err))
+ }
+}
+
+func handleCreateArtifactKey(cmd *cobra.Command, rootPrivKeyFile, artifactPrivKeyFile, artifactPubKeyFile string, expiration time.Duration) error {
+ cmd.Println("Creating new artifact signing key...")
+
+ privKeyPEM, err := os.ReadFile(rootPrivKeyFile)
+ if err != nil {
+ return fmt.Errorf("read root private key file: %w", err)
+ }
+
+ privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
+ if err != nil {
+ return fmt.Errorf("failed to parse private root key: %w", err)
+ }
+
+ artifactKey, privPEM, pubPEM, signature, err := reposign.GenerateArtifactKey(privateRootKey, expiration)
+ if err != nil {
+ return fmt.Errorf("generate artifact key: %w", err)
+ }
+
+ if err := os.WriteFile(artifactPrivKeyFile, privPEM, 0o600); err != nil {
+ return fmt.Errorf("write private key file (%s): %w", artifactPrivKeyFile, err)
+ }
+
+ if err := os.WriteFile(artifactPubKeyFile, pubPEM, 0o600); err != nil {
+ return fmt.Errorf("write public key file (%s): %w", artifactPubKeyFile, err)
+ }
+
+ signatureFile := artifactPubKeyFile + ".sig"
+ if err := os.WriteFile(signatureFile, signature, 0o600); err != nil {
+ return fmt.Errorf("write signature file (%s): %w", signatureFile, err)
+ }
+
+ cmd.Printf("✅ Artifact key created successfully.\n")
+ cmd.Printf("%s\n", artifactKey.String())
+ return nil
+}
+
+func handleBundlePubKeys(cmd *cobra.Command, rootPrivKeyFile string, artifactPubKeyFiles []string, bundlePubKeysFile string) error {
+ cmd.Println("📦 Bundling public keys into signed package...")
+
+ privKeyPEM, err := os.ReadFile(rootPrivKeyFile)
+ if err != nil {
+ return fmt.Errorf("read root private key file: %w", err)
+ }
+
+ privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
+ if err != nil {
+ return fmt.Errorf("failed to parse private root key: %w", err)
+ }
+
+ publicKeys := make([]reposign.PublicKey, 0, len(artifactPubKeyFiles))
+ for _, pubFile := range artifactPubKeyFiles {
+ pubPem, err := os.ReadFile(pubFile)
+ if err != nil {
+ return fmt.Errorf("read public key file: %w", err)
+ }
+
+ pk, err := reposign.ParseArtifactPubKey(pubPem)
+ if err != nil {
+ return fmt.Errorf("failed to parse artifact key: %w", err)
+ }
+ publicKeys = append(publicKeys, pk)
+ }
+
+ parsedKeys, signature, err := reposign.BundleArtifactKeys(privateRootKey, publicKeys)
+ if err != nil {
+ return fmt.Errorf("bundle artifact keys: %w", err)
+ }
+
+ if err := os.WriteFile(bundlePubKeysFile, parsedKeys, 0o600); err != nil {
+ return fmt.Errorf("write public keys file (%s): %w", bundlePubKeysFile, err)
+ }
+
+ signatureFile := bundlePubKeysFile + ".sig"
+ if err := os.WriteFile(signatureFile, signature, 0o600); err != nil {
+ return fmt.Errorf("write signature file (%s): %w", signatureFile, err)
+ }
+
+ cmd.Printf("✅ Bundle created with %d public keys.\n", len(artifactPubKeyFiles))
+ return nil
+}
diff --git a/client/cmd/signer/artifactsign.go b/client/cmd/signer/artifactsign.go
new file mode 100644
index 000000000..881be9367
--- /dev/null
+++ b/client/cmd/signer/artifactsign.go
@@ -0,0 +1,276 @@
+package main
+
+import (
+ "fmt"
+ "os"
+
+ "github.com/spf13/cobra"
+
+ "github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
+)
+
+const (
+ envArtifactPrivateKey = "NB_ARTIFACT_PRIV_KEY"
+)
+
+var (
+ signArtifactPrivKeyFile string
+ signArtifactArtifactFile string
+
+ verifyArtifactPubKeyFile string
+ verifyArtifactFile string
+ verifyArtifactSignatureFile string
+
+ verifyArtifactKeyPubKeyFile string
+ verifyArtifactKeyRootPubKeyFile string
+ verifyArtifactKeySignatureFile string
+ verifyArtifactKeyRevocationFile string
+)
+
+var signArtifactCmd = &cobra.Command{
+ Use: "sign-artifact",
+ Short: "Sign an artifact using an artifact private key",
+ Long: `Sign a software artifact (e.g., update bundle or binary) using the artifact's private key.
+This command produces a detached signature that can be verified using the corresponding artifact public key.`,
+ SilenceUsage: true,
+ RunE: func(cmd *cobra.Command, args []string) error {
+ if err := handleSignArtifact(cmd, signArtifactPrivKeyFile, signArtifactArtifactFile); err != nil {
+ return fmt.Errorf("failed to sign artifact: %w", err)
+ }
+ return nil
+ },
+}
+
+var verifyArtifactCmd = &cobra.Command{
+ Use: "verify-artifact",
+ Short: "Verify an artifact signature using an artifact public key",
+ Long: `Verify a software artifact signature using the artifact's public key.`,
+ SilenceUsage: true,
+ RunE: func(cmd *cobra.Command, args []string) error {
+ if err := handleVerifyArtifact(cmd, verifyArtifactPubKeyFile, verifyArtifactFile, verifyArtifactSignatureFile); err != nil {
+ return fmt.Errorf("failed to verify artifact: %w", err)
+ }
+ return nil
+ },
+}
+
+var verifyArtifactKeyCmd = &cobra.Command{
+ Use: "verify-artifact-key",
+ Short: "Verify an artifact public key was signed by a root key",
+ Long: `Verify that an artifact public key (or bundle) was properly signed by a root key.
+This validates the chain of trust from the root key to the artifact key.`,
+ SilenceUsage: true,
+ RunE: func(cmd *cobra.Command, args []string) error {
+ if err := handleVerifyArtifactKey(cmd, verifyArtifactKeyPubKeyFile, verifyArtifactKeyRootPubKeyFile, verifyArtifactKeySignatureFile, verifyArtifactKeyRevocationFile); err != nil {
+ return fmt.Errorf("failed to verify artifact key: %w", err)
+ }
+ return nil
+ },
+}
+
+func init() {
+ rootCmd.AddCommand(signArtifactCmd)
+ rootCmd.AddCommand(verifyArtifactCmd)
+ rootCmd.AddCommand(verifyArtifactKeyCmd)
+
+ signArtifactCmd.Flags().StringVar(&signArtifactPrivKeyFile, "artifact-key-file", "", fmt.Sprintf("Path to the artifact private key file used for signing (or set %s env var)", envArtifactPrivateKey))
+ signArtifactCmd.Flags().StringVar(&signArtifactArtifactFile, "artifact-file", "", "Path to the artifact to be signed")
+
+ // artifact-file is required, but artifact-key-file can come from env var
+ if err := signArtifactCmd.MarkFlagRequired("artifact-file"); err != nil {
+ panic(fmt.Errorf("mark artifact-file as required: %w", err))
+ }
+
+ verifyArtifactCmd.Flags().StringVar(&verifyArtifactPubKeyFile, "artifact-public-key-file", "", "Path to the artifact public key file")
+ verifyArtifactCmd.Flags().StringVar(&verifyArtifactFile, "artifact-file", "", "Path to the artifact to be verified")
+ verifyArtifactCmd.Flags().StringVar(&verifyArtifactSignatureFile, "signature-file", "", "Path to the signature file")
+
+ if err := verifyArtifactCmd.MarkFlagRequired("artifact-public-key-file"); err != nil {
+ panic(fmt.Errorf("mark artifact-public-key-file as required: %w", err))
+ }
+ if err := verifyArtifactCmd.MarkFlagRequired("artifact-file"); err != nil {
+ panic(fmt.Errorf("mark artifact-file as required: %w", err))
+ }
+ if err := verifyArtifactCmd.MarkFlagRequired("signature-file"); err != nil {
+ panic(fmt.Errorf("mark signature-file as required: %w", err))
+ }
+
+ verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyPubKeyFile, "artifact-key-file", "", "Path to the artifact public key file or bundle")
+ verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRootPubKeyFile, "root-key-file", "", "Path to the root public key file or bundle")
+ verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeySignatureFile, "signature-file", "", "Path to the signature file")
+ verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRevocationFile, "revocation-file", "", "Path to the revocation list file (optional)")
+
+ if err := verifyArtifactKeyCmd.MarkFlagRequired("artifact-key-file"); err != nil {
+ panic(fmt.Errorf("mark artifact-key-file as required: %w", err))
+ }
+ if err := verifyArtifactKeyCmd.MarkFlagRequired("root-key-file"); err != nil {
+ panic(fmt.Errorf("mark root-key-file as required: %w", err))
+ }
+ if err := verifyArtifactKeyCmd.MarkFlagRequired("signature-file"); err != nil {
+ panic(fmt.Errorf("mark signature-file as required: %w", err))
+ }
+}
+
+func handleSignArtifact(cmd *cobra.Command, privKeyFile, artifactFile string) error {
+ cmd.Println("🖋️ Signing artifact...")
+
+ // Load private key from env var or file
+ var privKeyPEM []byte
+ var err error
+
+ if envKey := os.Getenv(envArtifactPrivateKey); envKey != "" {
+ // Use key from environment variable
+ privKeyPEM = []byte(envKey)
+ } else if privKeyFile != "" {
+ // Fall back to file
+ privKeyPEM, err = os.ReadFile(privKeyFile)
+ if err != nil {
+ return fmt.Errorf("read private key file: %w", err)
+ }
+ } else {
+ return fmt.Errorf("artifact private key must be provided via %s environment variable or --artifact-key-file flag", envArtifactPrivateKey)
+ }
+
+ privateKey, err := reposign.ParseArtifactKey(privKeyPEM)
+ if err != nil {
+ return fmt.Errorf("failed to parse artifact private key: %w", err)
+ }
+
+ artifactData, err := os.ReadFile(artifactFile)
+ if err != nil {
+ return fmt.Errorf("read artifact file: %w", err)
+ }
+
+ signature, err := reposign.SignData(privateKey, artifactData)
+ if err != nil {
+ return fmt.Errorf("sign artifact: %w", err)
+ }
+
+ sigFile := artifactFile + ".sig"
+ if err := os.WriteFile(sigFile, signature, 0o600); err != nil {
+ return fmt.Errorf("write signature file (%s): %w", sigFile, err)
+ }
+
+ cmd.Printf("✅ Artifact signed successfully.\n")
+ cmd.Printf("Signature file: %s\n", sigFile)
+ return nil
+}
+
+func handleVerifyArtifact(cmd *cobra.Command, pubKeyFile, artifactFile, signatureFile string) error {
+ cmd.Println("🔍 Verifying artifact...")
+
+ // Read artifact public key
+ pubKeyPEM, err := os.ReadFile(pubKeyFile)
+ if err != nil {
+ return fmt.Errorf("read public key file: %w", err)
+ }
+
+ publicKey, err := reposign.ParseArtifactPubKey(pubKeyPEM)
+ if err != nil {
+ return fmt.Errorf("failed to parse artifact public key: %w", err)
+ }
+
+ // Read artifact data
+ artifactData, err := os.ReadFile(artifactFile)
+ if err != nil {
+ return fmt.Errorf("read artifact file: %w", err)
+ }
+
+ // Read signature
+ sigBytes, err := os.ReadFile(signatureFile)
+ if err != nil {
+ return fmt.Errorf("read signature file: %w", err)
+ }
+
+ signature, err := reposign.ParseSignature(sigBytes)
+ if err != nil {
+ return fmt.Errorf("failed to parse signature: %w", err)
+ }
+
+ // Validate artifact
+ if err := reposign.ValidateArtifact([]reposign.PublicKey{publicKey}, artifactData, *signature); err != nil {
+ return fmt.Errorf("artifact verification failed: %w", err)
+ }
+
+ cmd.Println("✅ Artifact signature is valid")
+ cmd.Printf("Artifact: %s\n", artifactFile)
+ cmd.Printf("Signed by key: %s\n", signature.KeyID)
+ cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST"))
+ return nil
+}
+
+func handleVerifyArtifactKey(cmd *cobra.Command, artifactKeyFile, rootKeyFile, signatureFile, revocationFile string) error {
+ cmd.Println("🔍 Verifying artifact key...")
+
+ // Read artifact key data
+ artifactKeyData, err := os.ReadFile(artifactKeyFile)
+ if err != nil {
+ return fmt.Errorf("read artifact key file: %w", err)
+ }
+
+ // Read root public key(s)
+ rootKeyData, err := os.ReadFile(rootKeyFile)
+ if err != nil {
+ return fmt.Errorf("read root key file: %w", err)
+ }
+
+ rootPublicKeys, err := parseRootPublicKeys(rootKeyData)
+ if err != nil {
+ return fmt.Errorf("failed to parse root public key(s): %w", err)
+ }
+
+ // Read signature
+ sigBytes, err := os.ReadFile(signatureFile)
+ if err != nil {
+ return fmt.Errorf("read signature file: %w", err)
+ }
+
+ signature, err := reposign.ParseSignature(sigBytes)
+ if err != nil {
+ return fmt.Errorf("failed to parse signature: %w", err)
+ }
+
+ // Read optional revocation list
+ var revocationList *reposign.RevocationList
+ if revocationFile != "" {
+ revData, err := os.ReadFile(revocationFile)
+ if err != nil {
+ return fmt.Errorf("read revocation file: %w", err)
+ }
+
+ revocationList, err = reposign.ParseRevocationList(revData)
+ if err != nil {
+ return fmt.Errorf("failed to parse revocation list: %w", err)
+ }
+ }
+
+ // Validate artifact key(s)
+ validKeys, err := reposign.ValidateArtifactKeys(rootPublicKeys, artifactKeyData, *signature, revocationList)
+ if err != nil {
+ return fmt.Errorf("artifact key verification failed: %w", err)
+ }
+
+ cmd.Println("✅ Artifact key(s) verified successfully")
+ cmd.Printf("Signed by root key: %s\n", signature.KeyID)
+ cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST"))
+ cmd.Printf("\nValid artifact keys (%d):\n", len(validKeys))
+ for i, key := range validKeys {
+ cmd.Printf(" [%d] Key ID: %s\n", i+1, key.Metadata.ID)
+ cmd.Printf(" Created: %s\n", key.Metadata.CreatedAt.Format("2006-01-02 15:04:05 MST"))
+ if !key.Metadata.ExpiresAt.IsZero() {
+ cmd.Printf(" Expires: %s\n", key.Metadata.ExpiresAt.Format("2006-01-02 15:04:05 MST"))
+ } else {
+ cmd.Printf(" Expires: Never\n")
+ }
+ }
+ return nil
+}
+
+// parseRootPublicKeys parses a root public key from PEM data
+func parseRootPublicKeys(data []byte) ([]reposign.PublicKey, error) {
+ key, err := reposign.ParseRootPublicKey(data)
+ if err != nil {
+ return nil, err
+ }
+ return []reposign.PublicKey{key}, nil
+}
diff --git a/client/cmd/signer/main.go b/client/cmd/signer/main.go
new file mode 100644
index 000000000..407093d07
--- /dev/null
+++ b/client/cmd/signer/main.go
@@ -0,0 +1,21 @@
+package main
+
+import (
+ "os"
+
+ "github.com/spf13/cobra"
+)
+
+var rootCmd = &cobra.Command{
+ Use: "signer",
+ Short: "A CLI tool for managing cryptographic keys and artifacts",
+ Long: `signer is a command-line tool that helps you manage
+root keys, artifact keys, and revocation lists securely.`,
+}
+
+func main() {
+ if err := rootCmd.Execute(); err != nil {
+ rootCmd.Println(err)
+ os.Exit(1)
+ }
+}
diff --git a/client/cmd/signer/revocation.go b/client/cmd/signer/revocation.go
new file mode 100644
index 000000000..1d84b65c3
--- /dev/null
+++ b/client/cmd/signer/revocation.go
@@ -0,0 +1,220 @@
+package main
+
+import (
+ "fmt"
+ "os"
+ "time"
+
+ "github.com/spf13/cobra"
+
+ "github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
+)
+
+const (
+ defaultRevocationListExpiration = 365 * 24 * time.Hour // 1 year
+)
+
+var (
+ keyID string
+ revocationListFile string
+ privateRootKeyFile string
+ publicRootKeyFile string
+ signatureFile string
+ expirationDuration time.Duration
+)
+
+var createRevocationListCmd = &cobra.Command{
+ Use: "create-revocation-list",
+ Short: "Create a new revocation list signed by the private root key",
+ SilenceUsage: true,
+ RunE: func(cmd *cobra.Command, args []string) error {
+ return handleCreateRevocationList(cmd, revocationListFile, privateRootKeyFile)
+ },
+}
+
+var extendRevocationListCmd = &cobra.Command{
+ Use: "extend-revocation-list",
+ Short: "Extend an existing revocation list with a given key ID",
+ SilenceUsage: true,
+ RunE: func(cmd *cobra.Command, args []string) error {
+ return handleExtendRevocationList(cmd, keyID, revocationListFile, privateRootKeyFile)
+ },
+}
+
+var verifyRevocationListCmd = &cobra.Command{
+ Use: "verify-revocation-list",
+ Short: "Verify a revocation list signature using the public root key",
+ SilenceUsage: true,
+ RunE: func(cmd *cobra.Command, args []string) error {
+ return handleVerifyRevocationList(cmd, revocationListFile, signatureFile, publicRootKeyFile)
+ },
+}
+
+func init() {
+ rootCmd.AddCommand(createRevocationListCmd)
+ rootCmd.AddCommand(extendRevocationListCmd)
+ rootCmd.AddCommand(verifyRevocationListCmd)
+
+ createRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file")
+ createRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file")
+ createRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)")
+ if err := createRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
+ panic(err)
+ }
+ if err := createRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil {
+ panic(err)
+ }
+
+ extendRevocationListCmd.Flags().StringVar(&keyID, "key-id", "", "ID of the key to extend the revocation list for")
+ extendRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file")
+ extendRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file")
+ extendRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)")
+ if err := extendRevocationListCmd.MarkFlagRequired("key-id"); err != nil {
+ panic(err)
+ }
+ if err := extendRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
+ panic(err)
+ }
+ if err := extendRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil {
+ panic(err)
+ }
+
+ verifyRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the revocation list file")
+ verifyRevocationListCmd.Flags().StringVar(&signatureFile, "signature-file", "", "Path to the signature file")
+ verifyRevocationListCmd.Flags().StringVar(&publicRootKeyFile, "public-root-key", "", "Path to the public root key PEM file")
+ if err := verifyRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
+ panic(err)
+ }
+ if err := verifyRevocationListCmd.MarkFlagRequired("signature-file"); err != nil {
+ panic(err)
+ }
+ if err := verifyRevocationListCmd.MarkFlagRequired("public-root-key"); err != nil {
+ panic(err)
+ }
+}
+
+func handleCreateRevocationList(cmd *cobra.Command, revocationListFile string, privateRootKeyFile string) error {
+ privKeyPEM, err := os.ReadFile(privateRootKeyFile)
+ if err != nil {
+ return fmt.Errorf("failed to read private root key file: %w", err)
+ }
+
+ privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
+ if err != nil {
+ return fmt.Errorf("failed to parse private root key: %w", err)
+ }
+
+ rlBytes, sigBytes, err := reposign.CreateRevocationList(*privateRootKey, expirationDuration)
+ if err != nil {
+ return fmt.Errorf("failed to create revocation list: %w", err)
+ }
+
+ if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", rlBytes, sigBytes); err != nil {
+ return fmt.Errorf("failed to write output files: %w", err)
+ }
+
+ cmd.Println("✅ Revocation list created successfully")
+ return nil
+}
+
+func handleExtendRevocationList(cmd *cobra.Command, keyID, revocationListFile, privateRootKeyFile string) error {
+ privKeyPEM, err := os.ReadFile(privateRootKeyFile)
+ if err != nil {
+ return fmt.Errorf("failed to read private root key file: %w", err)
+ }
+
+ privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
+ if err != nil {
+ return fmt.Errorf("failed to parse private root key: %w", err)
+ }
+
+ rlBytes, err := os.ReadFile(revocationListFile)
+ if err != nil {
+ return fmt.Errorf("failed to read revocation list file: %w", err)
+ }
+
+ rl, err := reposign.ParseRevocationList(rlBytes)
+ if err != nil {
+ return fmt.Errorf("failed to parse revocation list: %w", err)
+ }
+
+ kid, err := reposign.ParseKeyID(keyID)
+ if err != nil {
+ return fmt.Errorf("invalid key ID: %w", err)
+ }
+
+ newRLBytes, sigBytes, err := reposign.ExtendRevocationList(*privateRootKey, *rl, kid, expirationDuration)
+ if err != nil {
+ return fmt.Errorf("failed to extend revocation list: %w", err)
+ }
+
+ if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", newRLBytes, sigBytes); err != nil {
+ return fmt.Errorf("failed to write output files: %w", err)
+ }
+
+ cmd.Println("✅ Revocation list extended successfully")
+ return nil
+}
+
+func handleVerifyRevocationList(cmd *cobra.Command, revocationListFile, signatureFile, publicRootKeyFile string) error {
+ // Read revocation list file
+ rlBytes, err := os.ReadFile(revocationListFile)
+ if err != nil {
+ return fmt.Errorf("failed to read revocation list file: %w", err)
+ }
+
+ // Read signature file
+ sigBytes, err := os.ReadFile(signatureFile)
+ if err != nil {
+ return fmt.Errorf("failed to read signature file: %w", err)
+ }
+
+ // Read public root key file
+ pubKeyPEM, err := os.ReadFile(publicRootKeyFile)
+ if err != nil {
+ return fmt.Errorf("failed to read public root key file: %w", err)
+ }
+
+ // Parse public root key
+ publicKey, err := reposign.ParseRootPublicKey(pubKeyPEM)
+ if err != nil {
+ return fmt.Errorf("failed to parse public root key: %w", err)
+ }
+
+ // Parse signature
+ signature, err := reposign.ParseSignature(sigBytes)
+ if err != nil {
+ return fmt.Errorf("failed to parse signature: %w", err)
+ }
+
+ // Validate revocation list
+ rl, err := reposign.ValidateRevocationList([]reposign.PublicKey{publicKey}, rlBytes, *signature)
+ if err != nil {
+ return fmt.Errorf("failed to validate revocation list: %w", err)
+ }
+
+ // Display results
+ cmd.Println("✅ Revocation list signature is valid")
+ cmd.Printf("Last Updated: %s\n", rl.LastUpdated.Format(time.RFC3339))
+ cmd.Printf("Expires At: %s\n", rl.ExpiresAt.Format(time.RFC3339))
+ cmd.Printf("Number of revoked keys: %d\n", len(rl.Revoked))
+
+ if len(rl.Revoked) > 0 {
+ cmd.Println("\nRevoked Keys:")
+ for keyID, revokedTime := range rl.Revoked {
+ cmd.Printf(" - %s (revoked at: %s)\n", keyID, revokedTime.Format(time.RFC3339))
+ }
+ }
+
+ return nil
+}
+
+func writeOutputFiles(rlPath, sigPath string, rlBytes, sigBytes []byte) error {
+ if err := os.WriteFile(rlPath, rlBytes, 0o600); err != nil {
+ return fmt.Errorf("failed to write revocation list file: %w", err)
+ }
+ if err := os.WriteFile(sigPath, sigBytes, 0o600); err != nil {
+ return fmt.Errorf("failed to write signature file: %w", err)
+ }
+ return nil
+}
diff --git a/client/cmd/signer/rootkey.go b/client/cmd/signer/rootkey.go
new file mode 100644
index 000000000..78ac36b41
--- /dev/null
+++ b/client/cmd/signer/rootkey.go
@@ -0,0 +1,74 @@
+package main
+
+import (
+ "fmt"
+ "os"
+ "time"
+
+ "github.com/spf13/cobra"
+
+ "github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
+)
+
+var (
+ privKeyFile string
+ pubKeyFile string
+ rootExpiration time.Duration
+)
+
+var createRootKeyCmd = &cobra.Command{
+ Use: "create-root-key",
+ Short: "Create a new root key pair",
+ Long: `Create a new root key pair and specify an expiration time for it.`,
+ SilenceUsage: true,
+ RunE: func(cmd *cobra.Command, args []string) error {
+ // Validate expiration
+ if rootExpiration <= 0 {
+ return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 8760h)")
+ }
+
+ // Run main logic
+ if err := handleGenerateRootKey(cmd, privKeyFile, pubKeyFile, rootExpiration); err != nil {
+ return fmt.Errorf("failed to generate root key: %w", err)
+ }
+ return nil
+ },
+}
+
+func init() {
+ rootCmd.AddCommand(createRootKeyCmd)
+ createRootKeyCmd.Flags().StringVar(&privKeyFile, "priv-key-file", "", "Path to output private key file")
+ createRootKeyCmd.Flags().StringVar(&pubKeyFile, "pub-key-file", "", "Path to output public key file")
+ createRootKeyCmd.Flags().DurationVar(&rootExpiration, "expiration", 0, "Expiration time for the root key (e.g., 720h)")
+
+ if err := createRootKeyCmd.MarkFlagRequired("priv-key-file"); err != nil {
+ panic(err)
+ }
+ if err := createRootKeyCmd.MarkFlagRequired("pub-key-file"); err != nil {
+ panic(err)
+ }
+ if err := createRootKeyCmd.MarkFlagRequired("expiration"); err != nil {
+ panic(err)
+ }
+}
+
+func handleGenerateRootKey(cmd *cobra.Command, privKeyFile, pubKeyFile string, expiration time.Duration) error {
+ rk, privPEM, pubPEM, err := reposign.GenerateRootKey(expiration)
+ if err != nil {
+ return fmt.Errorf("generate root key: %w", err)
+ }
+
+ // Write private key
+ if err := os.WriteFile(privKeyFile, privPEM, 0o600); err != nil {
+ return fmt.Errorf("write private key file (%s): %w", privKeyFile, err)
+ }
+
+ // Write public key
+ if err := os.WriteFile(pubKeyFile, pubPEM, 0o600); err != nil {
+ return fmt.Errorf("write public key file (%s): %w", pubKeyFile, err)
+ }
+
+ cmd.Printf("%s\n\n", rk.String())
+ cmd.Printf("✅ Root key pair generated successfully.\n")
+ return nil
+}
diff --git a/client/cmd/ssh.go b/client/cmd/ssh.go
index 525bcdef1..0acf0b133 100644
--- a/client/cmd/ssh.go
+++ b/client/cmd/ssh.go
@@ -634,7 +634,11 @@ func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward
return err
}
- cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr)
+ if err := validateDestinationPort(remoteAddr); err != nil {
+ return fmt.Errorf("invalid remote address: %w", err)
+ }
+
+ log.Debugf("Local port forwarding: %s -> %s", localAddr, remoteAddr)
go func() {
if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) {
@@ -652,7 +656,11 @@ func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forwar
return err
}
- cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr)
+ if err := validateDestinationPort(localAddr); err != nil {
+ return fmt.Errorf("invalid local address: %w", err)
+ }
+
+ log.Debugf("Remote port forwarding: %s -> %s", remoteAddr, localAddr)
go func() {
if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) {
@@ -663,6 +671,35 @@ func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forwar
return nil
}
+// validateDestinationPort checks that the destination address has a valid port.
+// Port 0 is only valid for bind addresses (where the OS picks an available port),
+// not for destination addresses where we need to connect.
+func validateDestinationPort(addr string) error {
+ if strings.HasPrefix(addr, "/") || strings.HasPrefix(addr, "./") {
+ return nil
+ }
+
+ _, portStr, err := net.SplitHostPort(addr)
+ if err != nil {
+ return fmt.Errorf("parse address %s: %w", addr, err)
+ }
+
+ port, err := strconv.Atoi(portStr)
+ if err != nil {
+ return fmt.Errorf("invalid port %s: %w", portStr, err)
+ }
+
+ if port == 0 {
+ return fmt.Errorf("port 0 is not valid for destination address")
+ }
+
+ if port < 0 || port > 65535 {
+ return fmt.Errorf("port %d out of range (1-65535)", port)
+ }
+
+ return nil
+}
+
// parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80".
// Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket".
func parsePortForwardSpec(spec string) (string, string, error) {
diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go
index b9ff35945..888a9a3f7 100644
--- a/client/cmd/testutil_test.go
+++ b/client/cmd/testutil_test.go
@@ -127,7 +127,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
if err != nil {
t.Fatal(err)
}
- mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
+ mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
t.Fatal(err)
}
diff --git a/client/cmd/up.go b/client/cmd/up.go
index 140ba2cb2..9efc2e60d 100644
--- a/client/cmd/up.go
+++ b/client/cmd/up.go
@@ -197,7 +197,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
- connectClient := internal.NewConnectClient(ctx, config, r)
+ connectClient := internal.NewConnectClient(ctx, config, r, false)
SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil)
diff --git a/client/cmd/update.go b/client/cmd/update.go
new file mode 100644
index 000000000..dc49b02c3
--- /dev/null
+++ b/client/cmd/update.go
@@ -0,0 +1,13 @@
+//go:build !windows && !darwin
+
+package cmd
+
+import (
+ "github.com/spf13/cobra"
+)
+
+var updateCmd *cobra.Command
+
+func isUpdateBinary() bool {
+ return false
+}
diff --git a/client/cmd/update_supported.go b/client/cmd/update_supported.go
new file mode 100644
index 000000000..977875093
--- /dev/null
+++ b/client/cmd/update_supported.go
@@ -0,0 +1,75 @@
+//go:build windows || darwin
+
+package cmd
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "strings"
+
+ log "github.com/sirupsen/logrus"
+ "github.com/spf13/cobra"
+
+ "github.com/netbirdio/netbird/client/internal/updatemanager/installer"
+ "github.com/netbirdio/netbird/util"
+)
+
+var (
+ updateCmd = &cobra.Command{
+ Use: "update",
+ Short: "Update the NetBird client application",
+ RunE: updateFunc,
+ }
+
+ tempDirFlag string
+ installerFile string
+ serviceDirFlag string
+ dryRunFlag bool
+)
+
+func init() {
+ updateCmd.Flags().StringVar(&tempDirFlag, "temp-dir", "", "temporary dir")
+ updateCmd.Flags().StringVar(&installerFile, "installer-file", "", "installer file")
+ updateCmd.Flags().StringVar(&serviceDirFlag, "service-dir", "", "service directory")
+ updateCmd.Flags().BoolVar(&dryRunFlag, "dry-run", false, "dry run the update process without making any changes")
+}
+
+// isUpdateBinary checks if the current executable is named "update" or "update.exe"
+func isUpdateBinary() bool {
+ // Remove extension for cross-platform compatibility
+ execPath, err := os.Executable()
+ if err != nil {
+ return false
+ }
+ baseName := filepath.Base(execPath)
+ name := strings.TrimSuffix(baseName, filepath.Ext(baseName))
+
+ return name == installer.UpdaterBinaryNameWithoutExtension()
+}
+
+func updateFunc(cmd *cobra.Command, args []string) error {
+ if err := setupLogToFile(tempDirFlag); err != nil {
+ return err
+ }
+
+ log.Infof("updater started: %s", serviceDirFlag)
+ updater := installer.NewWithDir(tempDirFlag)
+ if err := updater.Setup(context.Background(), dryRunFlag, installerFile, serviceDirFlag); err != nil {
+ log.Errorf("failed to update application: %v", err)
+ return err
+ }
+ return nil
+}
+
+func setupLogToFile(dir string) error {
+ logFile := filepath.Join(dir, installer.LogFile)
+
+ if _, err := os.Stat(logFile); err == nil {
+ if err := os.Remove(logFile); err != nil {
+ log.Errorf("failed to remove existing log file: %v\n", err)
+ }
+ }
+
+ return util.InitLog(logLevel, util.LogConsole, logFile)
+}
diff --git a/client/embed/embed.go b/client/embed/embed.go
index 3090ca6a2..353c5438f 100644
--- a/client/embed/embed.go
+++ b/client/embed/embed.go
@@ -173,7 +173,7 @@ func (c *Client) Start(startCtx context.Context) error {
}
recorder := peer.NewRecorder(c.config.ManagementURL.String())
- client := internal.NewConnectClient(ctx, c.config, recorder)
+ client := internal.NewConnectClient(ctx, c.config, recorder, false)
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available
diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go
index adec802c8..6b29c5606 100644
--- a/client/firewall/nftables/manager_linux_test.go
+++ b/client/firewall/nftables/manager_linux_test.go
@@ -386,6 +386,97 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
verifyIptablesOutput(t, stdout, stderr)
}
+func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
+ if check() != NFTABLES {
+ t.Skip("nftables not supported on this system")
+ }
+
+ if _, err := exec.LookPath("iptables-save"); err != nil {
+ t.Skipf("iptables-save not available on this system: %v", err)
+ }
+
+ // First ensure iptables-nft tables exist by running iptables-save
+ stdout, stderr := runIptablesSave(t)
+ verifyIptablesOutput(t, stdout, stderr)
+
+ manager, err := Create(ifaceMock, iface.DefaultMTU)
+ require.NoError(t, err, "failed to create manager")
+ require.NoError(t, manager.Init(nil))
+
+ t.Cleanup(func() {
+ err := manager.Close(nil)
+ require.NoError(t, err, "failed to reset manager state")
+
+ // Verify iptables output after reset
+ stdout, stderr := runIptablesSave(t)
+ verifyIptablesOutput(t, stdout, stderr)
+ })
+
+ const octet2Count = 25
+ const octet3Count = 255
+ prefixes := make([]netip.Prefix, 0, (octet2Count-1)*(octet3Count-1))
+ for i := 1; i < octet2Count; i++ {
+ for j := 1; j < octet3Count; j++ {
+ addr := netip.AddrFrom4([4]byte{192, byte(j), byte(i), 0})
+ prefixes = append(prefixes, netip.PrefixFrom(addr, 24))
+ }
+ }
+ _, err = manager.AddRouteFiltering(
+ nil,
+ prefixes,
+ fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
+ fw.ProtocolTCP,
+ nil,
+ &fw.Port{Values: []uint16{443}},
+ fw.ActionAccept,
+ )
+ require.NoError(t, err, "failed to add route filtering rule")
+
+ stdout, stderr = runIptablesSave(t)
+ verifyIptablesOutput(t, stdout, stderr)
+}
+
+func TestNftablesManagerCompatibilityWithIptablesForEmptyPrefixes(t *testing.T) {
+ if check() != NFTABLES {
+ t.Skip("nftables not supported on this system")
+ }
+
+ if _, err := exec.LookPath("iptables-save"); err != nil {
+ t.Skipf("iptables-save not available on this system: %v", err)
+ }
+
+ // First ensure iptables-nft tables exist by running iptables-save
+ stdout, stderr := runIptablesSave(t)
+ verifyIptablesOutput(t, stdout, stderr)
+
+ manager, err := Create(ifaceMock, iface.DefaultMTU)
+ require.NoError(t, err, "failed to create manager")
+ require.NoError(t, manager.Init(nil))
+
+ t.Cleanup(func() {
+ err := manager.Close(nil)
+ require.NoError(t, err, "failed to reset manager state")
+
+ // Verify iptables output after reset
+ stdout, stderr := runIptablesSave(t)
+ verifyIptablesOutput(t, stdout, stderr)
+ })
+
+ _, err = manager.AddRouteFiltering(
+ nil,
+ []netip.Prefix{},
+ fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
+ fw.ProtocolTCP,
+ nil,
+ &fw.Port{Values: []uint16{443}},
+ fw.ActionAccept,
+ )
+ require.NoError(t, err, "failed to add route filtering rule")
+
+ stdout, stderr = runIptablesSave(t)
+ verifyIptablesOutput(t, stdout, stderr)
+}
+
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
t.Helper()
require.Equal(t, len(got), len(want), "expression count mismatch")
diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go
index 7f95992da..b6e0cf5b2 100644
--- a/client/firewall/nftables/router_linux.go
+++ b/client/firewall/nftables/router_linux.go
@@ -48,9 +48,11 @@ const (
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
-)
-const refreshRulesMapError = "refresh rules map: %w"
+ // maxPrefixesSet: sets with around 1638 prefixes start to fail, so cap batches at 1500 for margin
+ maxPrefixesSet = 1500
+ refreshRulesMapError = "refresh rules map: %w"
+)
var (
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
@@ -513,16 +515,35 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
}
elements := convertPrefixesToSet(prefixes)
- if err := r.conn.AddSet(nfset, elements); err != nil {
- return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
- }
+ nElements := len(elements)
+ maxElements := maxPrefixesSet * 2
+ initialElements := elements[:min(maxElements, nElements)]
+
+ if err := r.conn.AddSet(nfset, initialElements); err != nil {
+ return nil, fmt.Errorf("error adding set %s: %w", setName, err)
+ }
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
+ log.Debugf("Created new ipset: %s with %d initial prefixes (total prefixes %d)", setName, len(initialElements)/2, len(prefixes))
- log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
+ var subEnd int
+ for subStart := maxElements; subStart < nElements; subStart += maxElements {
+ subEnd = min(subStart+maxElements, nElements)
+ subElement := elements[subStart:subEnd]
+ nSubPrefixes := len(subElement) / 2
+ log.Tracef("Adding new prefixes (%d) in ipset: %s", nSubPrefixes, setName)
+ if err := r.conn.SetAddElements(nfset, subElement); err != nil {
+ return nil, fmt.Errorf("error adding prefixes (%d) to set %s: %w", nSubPrefixes, setName, err)
+ }
+ if err := r.conn.Flush(); err != nil {
+ return nil, fmt.Errorf("flush error: %w", err)
+ }
+ log.Debugf("Added new prefixes (%d) in ipset: %s", nSubPrefixes, setName)
+ }
+ log.Infof("Created new ipset: %s with %d prefixes", setName, len(prefixes))
return nfset, nil
}
diff --git a/client/iface/device/device_ios.go b/client/iface/device/device_ios.go
index f96edf992..d841ac2fe 100644
--- a/client/iface/device/device_ios.go
+++ b/client/iface/device/device_ios.go
@@ -4,6 +4,7 @@
package device
import (
+ "fmt"
"os"
log "github.com/sirupsen/logrus"
@@ -45,10 +46,31 @@ func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu
}
}
+// ErrInvalidTunnelFD is returned when the tunnel file descriptor is invalid (0).
+// This typically means the Swift code couldn't find the utun control socket.
+var ErrInvalidTunnelFD = fmt.Errorf("invalid tunnel file descriptor: fd is 0 (Swift failed to locate utun socket)")
+
func (t *TunDevice) Create() (WGConfigurer, error) {
log.Infof("create tun interface")
- dupTunFd, err := unix.Dup(t.tunFd)
+ var tunDevice tun.Device
+ var err error
+
+ // Validate the tunnel file descriptor.
+ // On iOS/tvOS, the FD must be provided by the NEPacketTunnelProvider.
+ // A value of 0 means the Swift code couldn't find the utun control socket
+ // (the low-level APIs like ctl_info, sockaddr_ctl may not be exposed in
+ // tvOS SDK headers). This is a hard error - there's no viable fallback
+ // since tun.CreateTUN() cannot work within the iOS/tvOS sandbox.
+ if t.tunFd == 0 {
+ log.Errorf("Tunnel file descriptor is 0 - Swift code failed to locate the utun control socket. " +
+ "On tvOS, ensure the NEPacketTunnelProvider is properly configured and the tunnel is started.")
+ return nil, ErrInvalidTunnelFD
+ }
+
+ // Normal iOS/tvOS path: use the provided file descriptor from NEPacketTunnelProvider
+ var dupTunFd int
+ dupTunFd, err = unix.Dup(t.tunFd)
if err != nil {
log.Errorf("Unable to dup tun fd: %v", err)
return nil, err
@@ -60,7 +82,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
_ = unix.Close(dupTunFd)
return nil, err
}
- tunDevice, err := tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0)
+ tunDevice, err = tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0)
if err != nil {
log.Errorf("Unable to create new tun device from fd: %v", err)
_ = unix.Close(dupTunFd)
diff --git a/client/iface/wgproxy/factory_kernel.go b/client/iface/wgproxy/factory_kernel.go
index ad2807546..2714c5774 100644
--- a/client/iface/wgproxy/factory_kernel.go
+++ b/client/iface/wgproxy/factory_kernel.go
@@ -3,12 +3,19 @@
package wgproxy
import (
+ "os"
+ "strconv"
+
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
+const (
+ envDisableEBPFWGProxy = "NB_DISABLE_EBPF_WG_PROXY"
+)
+
type KernelFactory struct {
wgPort int
mtu uint16
@@ -22,6 +29,12 @@ func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory {
mtu: mtu,
}
+ if isEBPFDisabled() {
+ log.Infof("WireGuard Proxy Factory will produce UDP proxy")
+ log.Infof("eBPF WireGuard proxy is disabled via %s environment variable", envDisableEBPFWGProxy)
+ return f
+ }
+
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, mtu)
if err := ebpfProxy.Listen(); err != nil {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
@@ -47,3 +60,16 @@ func (w *KernelFactory) Free() error {
}
return w.ebpfProxy.Free()
}
+
+func isEBPFDisabled() bool {
+ val := os.Getenv(envDisableEBPFWGProxy)
+ if val == "" {
+ return false
+ }
+ disabled, err := strconv.ParseBool(val)
+ if err != nil {
+ log.Warnf("failed to parse %s: %v", envDisableEBPFWGProxy, err)
+ return false
+ }
+ return disabled
+}
diff --git a/client/internal/connect.go b/client/internal/connect.go
index e9d422a28..017c8bf10 100644
--- a/client/internal/connect.go
+++ b/client/internal/connect.go
@@ -24,10 +24,14 @@ import (
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
+ "github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
+ "github.com/netbirdio/netbird/client/internal/updatemanager"
+ "github.com/netbirdio/netbird/client/internal/updatemanager/installer"
nbnet "github.com/netbirdio/netbird/client/net"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ssh"
+ sshconfig "github.com/netbirdio/netbird/client/ssh/config"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
@@ -39,11 +43,13 @@ import (
)
type ConnectClient struct {
- ctx context.Context
- config *profilemanager.Config
- statusRecorder *peer.Status
- engine *Engine
- engineMutex sync.Mutex
+ ctx context.Context
+ config *profilemanager.Config
+ statusRecorder *peer.Status
+ doInitialAutoUpdate bool
+
+ engine *Engine
+ engineMutex sync.Mutex
persistSyncResponse bool
}
@@ -52,13 +58,15 @@ func NewConnectClient(
ctx context.Context,
config *profilemanager.Config,
statusRecorder *peer.Status,
+ doInitalAutoUpdate bool,
) *ConnectClient {
return &ConnectClient{
- ctx: ctx,
- config: config,
- statusRecorder: statusRecorder,
- engineMutex: sync.Mutex{},
+ ctx: ctx,
+ config: config,
+ statusRecorder: statusRecorder,
+ doInitialAutoUpdate: doInitalAutoUpdate,
+ engineMutex: sync.Mutex{},
}
}
@@ -162,6 +170,33 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return err
}
+ var path string
+ if runtime.GOOS == "ios" || runtime.GOOS == "android" {
+ // On mobile, use the provided state file path directly
+ if !fileExists(mobileDependency.StateFilePath) {
+ if err := createFile(mobileDependency.StateFilePath); err != nil {
+ log.Errorf("failed to create state file: %v", err)
+ // we are not exiting as we can run without the state manager
+ }
+ }
+ path = mobileDependency.StateFilePath
+ } else {
+ sm := profilemanager.NewServiceManager("")
+ path = sm.GetStatePath()
+ }
+ stateManager := statemanager.New(path)
+ stateManager.RegisterState(&sshconfig.ShutdownState{})
+
+ updateManager, err := updatemanager.NewManager(c.statusRecorder, stateManager)
+ if err == nil {
+ updateManager.CheckUpdateSuccess(c.ctx)
+
+ inst := installer.New()
+ if err := inst.CleanUpInstallerFiles(); err != nil {
+ log.Errorf("failed to clean up temporary installer file: %v", err)
+ }
+ }
+
defer c.statusRecorder.ClientStop()
operation := func() error {
// if context cancelled we not start new backoff cycle
@@ -273,7 +308,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
checks := loginResp.GetChecks()
c.engineMutex.Lock()
- engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
+ engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager)
engine.SetSyncResponsePersistence(c.persistSyncResponse)
c.engine = engine
c.engineMutex.Unlock()
@@ -283,6 +318,15 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return wrapErr(err)
}
+ if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil {
+ // doInitialAutoUpdate is true when the user clicks the "Connect" menu in the UI
+ if c.doInitialAutoUpdate {
+ log.Infof("start engine by ui, run auto-update check")
+ c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate)
+ c.doInitialAutoUpdate = false
+ }
+ }
+
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)
diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go
index 3c201ecfc..01a0377a5 100644
--- a/client/internal/debug/debug.go
+++ b/client/internal/debug/debug.go
@@ -27,6 +27,7 @@ import (
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
+ "github.com/netbirdio/netbird/client/internal/updatemanager/installer"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
@@ -362,6 +363,10 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add systemd logs: %v", err)
}
+ if err := g.addUpdateLogs(); err != nil {
+ log.Errorf("failed to add updater logs: %v", err)
+ }
+
return nil
}
@@ -650,6 +655,29 @@ func (g *BundleGenerator) addStateFile() error {
return nil
}
+func (g *BundleGenerator) addUpdateLogs() error {
+ inst := installer.New()
+ logFiles := inst.LogFiles()
+ if len(logFiles) == 0 {
+ return nil
+ }
+
+ log.Infof("adding updater logs")
+ for _, logFile := range logFiles {
+ data, err := os.ReadFile(logFile)
+ if err != nil {
+ log.Warnf("failed to read update log file %s: %v", logFile, err)
+ continue
+ }
+
+ baseName := filepath.Base(logFile)
+ if err := g.addFileToZip(bytes.NewReader(data), filepath.Join("update-logs", baseName)); err != nil {
+ return fmt.Errorf("add update log file %s to zip: %w", baseName, err)
+ }
+ }
+ return nil
+}
+
func (g *BundleGenerator) addCorruptedStateFiles() error {
sm := profilemanager.NewServiceManager("")
pattern := sm.GetStatePath()
diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go
index 290395473..d01be0c2c 100644
--- a/client/internal/dns/mgmt/mgmt.go
+++ b/client/internal/dns/mgmt/mgmt.go
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net"
+ "net/netip"
"net/url"
"strings"
"sync"
@@ -26,6 +27,11 @@ type Resolver struct {
mutex sync.RWMutex
}
+type ipsResponse struct {
+ ips []netip.Addr
+ err error
+}
+
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
return &Resolver{
@@ -99,9 +105,9 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
defer cancel()
- ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
+ ips, err := lookupIPWithExtraTimeout(ctx, d)
if err != nil {
- return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err)
+ return err
}
var aRecords, aaaaRecords []dns.RR
@@ -159,6 +165,36 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
return nil
}
+func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
+ log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
+ defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
+ resultChan := make(chan *ipsResponse, 1)
+
+ go func() {
+ ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
+ resultChan <- &ipsResponse{
+ err: err,
+ ips: ips,
+ }
+ }()
+
+ var resp *ipsResponse
+
+ select {
+ case <-time.After(dnsTimeout + time.Millisecond*500):
+ log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
+ return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case resp = <-resultChan:
+ }
+
+ if resp.err != nil {
+ return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
+ }
+ return resp.ips, nil
+}
+
// PopulateFromConfig extracts and caches domains from the client configuration.
func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error {
if mgmtURL == nil {
diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go
index afaf0579f..94945b55a 100644
--- a/client/internal/dns/server.go
+++ b/client/internal/dns/server.go
@@ -80,6 +80,7 @@ type DefaultServer struct {
updateSerial uint64
previousConfigHash uint64
currentConfig HostDNSConfig
+ currentConfigHash uint64
handlerChain *HandlerChain
extraDomains map[domain.Domain]int
@@ -207,6 +208,7 @@ func newDefaultServer(
hostsDNSHolder: newHostsDNSHolder(),
hostManager: &noopHostConfigurator{},
mgmtCacheResolver: mgmtCacheResolver,
+ currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied
}
// register with root zone, handler chain takes care of the routing
@@ -586,8 +588,29 @@ func (s *DefaultServer) applyHostConfig() {
log.Debugf("extra match domains: %v", maps.Keys(s.extraDomains))
+ hash, err := hashstructure.Hash(config, hashstructure.FormatV2, &hashstructure.HashOptions{
+ ZeroNil: true,
+ IgnoreZeroValue: true,
+ SlicesAsSets: true,
+ UseStringer: true,
+ })
+ if err != nil {
+ log.Warnf("unable to hash the host dns configuration, will apply config anyway: %s", err)
+ // Fall through to apply config anyway (fail-safe approach)
+ } else if s.currentConfigHash == hash {
+ log.Debugf("not applying host config as there are no changes")
+ return
+ }
+
+ log.Debugf("applying host config as there are changes")
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
log.Errorf("failed to apply DNS host manager update: %v", err)
+ return
+ }
+
+ // Only update hash if it was computed successfully and config was applied
+ if err == nil {
+ s.currentConfigHash = hash
}
s.registerFallback(config)
diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go
index d12070128..fe1f67f66 100644
--- a/client/internal/dns/server_test.go
+++ b/client/internal/dns/server_test.go
@@ -1602,7 +1602,10 @@ func TestExtraDomains(t *testing.T) {
"other.example.com.",
"duplicate.example.com.",
},
- applyHostConfigCall: 4,
+ // Expect 3 calls instead of 4 because when deregistering duplicate.example.com,
+ // the domain remains in the config (ref count goes from 2 to 1), so the host
+ // config hash doesn't change and applyDNSConfig is not called.
+ applyHostConfigCall: 3,
},
{
name: "Config update with new domains after registration",
@@ -1657,7 +1660,10 @@ func TestExtraDomains(t *testing.T) {
expectedMatchOnly: []string{
"extra.example.com.",
},
- applyHostConfigCall: 3,
+ // Expect 2 calls instead of 3 because when deregistering protected.example.com,
+ // it's removed from extraDomains but still remains in the config (from customZones),
+ // so the host config hash doesn't change and applyDNSConfig is not called.
+ applyHostConfigCall: 2,
},
{
name: "Register domain that is part of nameserver group",
diff --git a/client/internal/engine.go b/client/internal/engine.go
index ff1cec19a..4f18c3bc8 100644
--- a/client/internal/engine.go
+++ b/client/internal/engine.go
@@ -42,14 +42,13 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
- "github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
+ "github.com/netbirdio/netbird/client/internal/updatemanager"
cProto "github.com/netbirdio/netbird/client/proto"
- sshconfig "github.com/netbirdio/netbird/client/ssh/config"
"github.com/netbirdio/netbird/shared/management/domain"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
@@ -73,6 +72,7 @@ const (
PeerConnectionTimeoutMax = 45000 // ms
PeerConnectionTimeoutMin = 30000 // ms
connInitLimit = 200
+ disableAutoUpdate = "disabled"
)
var ErrResetConnection = fmt.Errorf("reset connection")
@@ -201,6 +201,9 @@ type Engine struct {
connSemaphore *semaphoregroup.SemaphoreGroup
flowManager nftypes.FlowManager
+ // auto-update
+ updateManager *updatemanager.Manager
+
// WireGuard interface monitor
wgIfaceMonitor *WGIfaceMonitor
@@ -221,17 +224,7 @@ type localIpUpdater interface {
}
// NewEngine creates a new Connection Engine with probes attached
-func NewEngine(
- clientCtx context.Context,
- clientCancel context.CancelFunc,
- signalClient signal.Client,
- mgmClient mgm.Client,
- relayManager *relayClient.Manager,
- config *EngineConfig,
- mobileDep MobileDependency,
- statusRecorder *peer.Status,
- checks []*mgmProto.Checks,
-) *Engine {
+func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, relayManager *relayClient.Manager, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, checks []*mgmProto.Checks, stateManager *statemanager.Manager) *Engine {
engine := &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
@@ -247,28 +240,12 @@ func NewEngine(
TURNs: []*stun.URI{},
networkSerial: 0,
statusRecorder: statusRecorder,
+ stateManager: stateManager,
checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
}
- sm := profilemanager.NewServiceManager("")
-
- path := sm.GetStatePath()
- if runtime.GOOS == "ios" || runtime.GOOS == "android" {
- if !fileExists(mobileDep.StateFilePath) {
- err := createFile(mobileDep.StateFilePath)
- if err != nil {
- log.Errorf("failed to create state file: %v", err)
- // we are not exiting as we can run without the state manager
- }
- }
-
- path = mobileDep.StateFilePath
- }
- engine.stateManager = statemanager.New(path)
- engine.stateManager.RegisterState(&sshconfig.ShutdownState{})
-
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
return engine
}
@@ -308,6 +285,10 @@ func (e *Engine) Stop() error {
e.srWatcher.Close()
}
+ if e.updateManager != nil {
+ e.updateManager.Stop()
+ }
+
log.Info("cleaning up status recorder states")
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
@@ -541,6 +522,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return nil
}
+func (e *Engine) InitialUpdateHandling(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
+ e.syncMsgMux.Lock()
+ defer e.syncMsgMux.Unlock()
+
+ e.handleAutoUpdateVersion(autoUpdateSettings, true)
+}
+
func (e *Engine) createFirewall() error {
if e.config.DisableFirewall {
log.Infof("firewall is disabled")
@@ -749,6 +737,41 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg
return nil
}
+func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings, initialCheck bool) {
+ if autoUpdateSettings == nil {
+ return
+ }
+
+ disabled := autoUpdateSettings.Version == disableAutoUpdate
+
+ // Stop and cleanup if disabled
+ if e.updateManager != nil && disabled {
+ log.Infof("auto-update is disabled, stopping update manager")
+ e.updateManager.Stop()
+ e.updateManager = nil
+ return
+ }
+
+ // Skip check unless AlwaysUpdate is enabled or this is the initial check at startup
+ if !autoUpdateSettings.AlwaysUpdate && !initialCheck {
+ log.Debugf("skipping auto-update check, AlwaysUpdate is false and this is not the initial check")
+ return
+ }
+
+ // Start manager if needed
+ if e.updateManager == nil {
+ log.Infof("starting auto-update manager")
+ updateManager, err := updatemanager.NewManager(e.statusRecorder, e.stateManager)
+ if err != nil {
+ return
+ }
+ e.updateManager = updateManager
+ e.updateManager.Start(e.ctx)
+ }
+ log.Infof("handling auto-update version: %s", autoUpdateSettings.Version)
+ e.updateManager.SetVersion(autoUpdateSettings.Version)
+}
+
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
@@ -758,6 +781,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return e.ctx.Err()
}
+ if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
+ e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate, false)
+ }
+
if update.GetNetbirdConfig() != nil {
wCfg := update.GetNetbirdConfig()
err := e.updateTURNs(wCfg.GetTurns())
@@ -1094,6 +1121,15 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.updateOfflinePeers(networkMap.GetOfflinePeers())
+ // Filter out own peer from the remote peers list
+ localPubKey := e.config.WgPrivateKey.PublicKey().String()
+ remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
+ for _, p := range networkMap.GetRemotePeers() {
+ if p.GetWgPubKey() != localPubKey {
+ remotePeers = append(remotePeers, p)
+ }
+ }
+
// cleanup request, most likely our peer has been deleted
if networkMap.GetRemotePeersIsEmpty() {
err := e.removeAllPeers()
@@ -1102,32 +1138,34 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return err
}
} else {
- err := e.removePeers(networkMap.GetRemotePeers())
+ err := e.removePeers(remotePeers)
if err != nil {
return err
}
- err = e.modifyPeers(networkMap.GetRemotePeers())
+ err = e.modifyPeers(remotePeers)
if err != nil {
return err
}
- err = e.addNewPeers(networkMap.GetRemotePeers())
+ err = e.addNewPeers(remotePeers)
if err != nil {
return err
}
e.statusRecorder.FinishPeerListModifications()
- e.updatePeerSSHHostKeys(networkMap.GetRemotePeers())
+ e.updatePeerSSHHostKeys(remotePeers)
- if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil {
+ if err := e.updateSSHClientConfig(remotePeers); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
}
+
+ e.updateSSHServerAuth(networkMap.GetSshAuth())
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
- excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, networkMap.GetRemotePeers())
+ excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
e.networkSerial = serial
diff --git a/client/internal/engine_ssh.go b/client/internal/engine_ssh.go
index 861b3d6d2..e683d8cee 100644
--- a/client/internal/engine_ssh.go
+++ b/client/internal/engine_ssh.go
@@ -11,15 +11,18 @@ import (
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
+ sshauth "github.com/netbirdio/netbird/client/ssh/auth"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
+ sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
type sshServer interface {
Start(ctx context.Context, addr netip.AddrPort) error
Stop() error
GetStatus() (bool, []sshserver.SessionInfo)
+ UpdateSSHAuth(config *sshauth.Config)
}
func (e *Engine) setupSSHPortRedirection() error {
@@ -353,3 +356,38 @@ func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.Sessio
return sshServer.GetStatus()
}
+
+// updateSSHServerAuth updates SSH fine-grained access control configuration on a running SSH server
+func (e *Engine) updateSSHServerAuth(sshAuth *mgmProto.SSHAuth) {
+ if sshAuth == nil {
+ return
+ }
+
+ if e.sshServer == nil {
+ return
+ }
+
+ protoUsers := sshAuth.GetAuthorizedUsers()
+ authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
+ for i, hash := range protoUsers {
+ if len(hash) != 16 {
+ log.Warnf("invalid hash length %d, expected 16 - skipping SSH server auth update", len(hash))
+ return
+ }
+ authorizedUsers[i] = sshuserhash.UserIDHash(hash)
+ }
+
+ machineUsers := make(map[string][]uint32)
+ for osUser, indexes := range sshAuth.GetMachineUsers() {
+ machineUsers[osUser] = indexes.GetIndexes()
+ }
+
+ // Update SSH server with new authorization configuration
+ authConfig := &sshauth.Config{
+ UserIDClaim: sshAuth.GetUserIDClaim(),
+ AuthorizedUsers: authorizedUsers,
+ MachineUsers: machineUsers,
+ }
+
+ e.sshServer.UpdateSSHAuth(authConfig)
+}
diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go
index 5ab21e3e1..a15ee0581 100644
--- a/client/internal/engine_test.go
+++ b/client/internal/engine_test.go
@@ -253,6 +253,7 @@ func TestEngine_SSH(t *testing.T) {
MobileDependency{},
peer.NewRecorder("https://mgm"),
nil,
+ nil,
)
engine.dnsServer = &dns.MockServer{
@@ -414,21 +415,13 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
- engine := NewEngine(
- ctx, cancel,
- &signal.MockClient{},
- &mgmt.MockClient{},
- relayMgr,
- &EngineConfig{
- WgIfaceName: "utun102",
- WgAddr: "100.64.0.1/24",
- WgPrivateKey: key,
- WgPort: 33100,
- MTU: iface.DefaultMTU,
- },
- MobileDependency{},
- peer.NewRecorder("https://mgm"),
- nil)
+ engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
+ WgIfaceName: "utun102",
+ WgAddr: "100.64.0.1/24",
+ WgPrivateKey: key,
+ WgPort: 33100,
+ MTU: iface.DefaultMTU,
+ }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
wgIface := &MockWGIface{
NameFunc: func() string { return "utun102" },
@@ -647,7 +640,7 @@ func TestEngine_Sync(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
- }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
+ }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
engine.ctx = ctx
engine.dnsServer = &dns.MockServer{
@@ -812,7 +805,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
- }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
+ }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
engine.ctx = ctx
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
@@ -1014,7 +1007,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
- }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
+ }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
engine.ctx = ctx
newNet, err := stdnet.NewNet(context.Background(), nil)
@@ -1540,7 +1533,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
- e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
+ e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil
e.ctx = ctx
return e, err
}
@@ -1638,7 +1631,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
if err != nil {
return nil, "", err
}
- mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController)
+ mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
return nil, "", err
}
diff --git a/client/internal/networkmonitor/check_change_darwin.go b/client/internal/networkmonitor/check_change_darwin.go
index ddc6e1736..cb5236070 100644
--- a/client/internal/networkmonitor/check_change_darwin.go
+++ b/client/internal/networkmonitor/check_change_darwin.go
@@ -110,7 +110,6 @@ func wakeUpListen(ctx context.Context) {
}
if newHash == initialHash {
- log.Tracef("no wakeup detected")
continue
}
diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go
index 426c31e1a..20a2eb342 100644
--- a/client/internal/peer/conn.go
+++ b/client/internal/peer/conn.go
@@ -148,13 +148,15 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
// be used.
func (conn *Conn) Open(engineCtx context.Context) error {
- conn.semaphore.Add(engineCtx)
+ if err := conn.semaphore.Add(engineCtx); err != nil {
+ return err
+ }
conn.mu.Lock()
defer conn.mu.Unlock()
if conn.opened {
- conn.semaphore.Done(engineCtx)
+ conn.semaphore.Done()
return nil
}
@@ -165,6 +167,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
if err != nil {
+ conn.semaphore.Done()
return err
}
conn.workerICE = workerICE
@@ -200,7 +203,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
defer conn.wg.Done()
conn.waitInitialRandomSleepTime(conn.ctx)
- conn.semaphore.Done(conn.ctx)
+ conn.semaphore.Done()
conn.guard.Start(conn.ctx, conn.onGuardEvent)
}()
diff --git a/client/internal/profilemanager/config.go b/client/internal/profilemanager/config.go
index 8f467a214..f2fda84e0 100644
--- a/client/internal/profilemanager/config.go
+++ b/client/internal/profilemanager/config.go
@@ -3,9 +3,11 @@ package profilemanager
import (
"context"
"crypto/tls"
+ "encoding/json"
"fmt"
"net/url"
"os"
+ "os/user"
"path/filepath"
"reflect"
"runtime"
@@ -165,19 +167,26 @@ func getConfigDir() (string, error) {
if ConfigDirOverride != "" {
return ConfigDirOverride, nil
}
- configDir, err := os.UserConfigDir()
+
+ base, err := baseConfigDir()
if err != nil {
return "", err
}
- configDir = filepath.Join(configDir, "netbird")
- if _, err := os.Stat(configDir); os.IsNotExist(err) {
- if err := os.MkdirAll(configDir, 0755); err != nil {
- return "", err
+ configDir := filepath.Join(base, "netbird")
+ if err := os.MkdirAll(configDir, 0o755); err != nil {
+ return "", err
+ }
+ return configDir, nil
+}
+
+func baseConfigDir() (string, error) {
+ if runtime.GOOS == "darwin" {
+ if u, err := user.Current(); err == nil && u.HomeDir != "" {
+ return filepath.Join(u.HomeDir, "Library", "Application Support"), nil
}
}
-
- return configDir, nil
+ return os.UserConfigDir()
}
func getConfigDirForUser(username string) (string, error) {
@@ -676,7 +685,7 @@ func update(input ConfigInput) (*Config, error) {
return config, nil
}
-// GetConfig read config file and return with Config. Errors out if it does not exist
+// GetConfig reads the config file and returns the Config. Errors out if it does not exist
func GetConfig(configPath string) (*Config, error) {
return readConfig(configPath, false)
}
@@ -812,3 +821,85 @@ func readConfig(configPath string, createIfMissing bool) (*Config, error) {
func WriteOutConfig(path string, config *Config) error {
return util.WriteJson(context.Background(), path, config)
}
+
+// DirectWriteOutConfig writes config directly without atomic temp file operations.
+// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
+func DirectWriteOutConfig(path string, config *Config) error {
+ return util.DirectWriteJson(context.Background(), path, config)
+}
+
+// DirectUpdateOrCreateConfig is like UpdateOrCreateConfig but uses direct (non-atomic) writes.
+// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
+func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) {
+ if !fileExists(input.ConfigPath) {
+ log.Infof("generating new config %s", input.ConfigPath)
+ cfg, err := createNewConfig(input)
+ if err != nil {
+ return nil, err
+ }
+ err = util.DirectWriteJson(context.Background(), input.ConfigPath, cfg)
+ return cfg, err
+ }
+
+ if isPreSharedKeyHidden(input.PreSharedKey) {
+ input.PreSharedKey = nil
+ }
+
+ // Enforce permissions on existing config files (same as UpdateOrCreateConfig)
+ if err := util.EnforcePermission(input.ConfigPath); err != nil {
+ log.Errorf("failed to enforce permission on config file: %v", err)
+ }
+
+ return directUpdate(input)
+}
+
+func directUpdate(input ConfigInput) (*Config, error) {
+ config := &Config{}
+
+ if _, err := util.ReadJson(input.ConfigPath, config); err != nil {
+ return nil, err
+ }
+
+ updated, err := config.apply(input)
+ if err != nil {
+ return nil, err
+ }
+
+ if updated {
+ if err := util.DirectWriteJson(context.Background(), input.ConfigPath, config); err != nil {
+ return nil, err
+ }
+ }
+
+ return config, nil
+}
+
+// ConfigToJSON serializes a Config struct to a JSON string.
+// This is useful for exporting config to alternative storage mechanisms
+// (e.g., UserDefaults on tvOS where file writes are blocked).
+func ConfigToJSON(config *Config) (string, error) {
+ bs, err := json.MarshalIndent(config, "", " ")
+ if err != nil {
+ return "", err
+ }
+ return string(bs), nil
+}
+
+// ConfigFromJSON deserializes a JSON string to a Config struct.
+// This is useful for restoring config from alternative storage mechanisms.
+// After unmarshaling, defaults are applied to ensure the config is fully initialized.
+func ConfigFromJSON(jsonStr string) (*Config, error) {
+ config := &Config{}
+ err := json.Unmarshal([]byte(jsonStr), config)
+ if err != nil {
+ return nil, err
+ }
+
+ // Apply defaults to ensure required fields are initialized.
+ // This mirrors what readConfig does after loading from file.
+ if _, err := config.apply(ConfigInput{}); err != nil {
+ return nil, fmt.Errorf("failed to apply defaults to config: %w", err)
+ }
+
+ return config, nil
+}
diff --git a/client/internal/profilemanager/service.go b/client/internal/profilemanager/service.go
index faccf5f68..bdb722c67 100644
--- a/client/internal/profilemanager/service.go
+++ b/client/internal/profilemanager/service.go
@@ -76,6 +76,7 @@ func (a *ActiveProfileState) FilePath() (string, error) {
}
type ServiceManager struct {
+ profilesDir string // If set, overrides ConfigDirOverride for profile operations
}
func NewServiceManager(defaultConfigPath string) *ServiceManager {
@@ -85,6 +86,17 @@ func NewServiceManager(defaultConfigPath string) *ServiceManager {
return &ServiceManager{}
}
+// NewServiceManagerWithProfilesDir creates a ServiceManager with a specific profiles directory
+// This allows setting the profiles directory without modifying the global ConfigDirOverride
+func NewServiceManagerWithProfilesDir(defaultConfigPath string, profilesDir string) *ServiceManager {
+ if defaultConfigPath != "" {
+ DefaultConfigPath = defaultConfigPath
+ }
+ return &ServiceManager{
+ profilesDir: profilesDir,
+ }
+}
+
func (s *ServiceManager) CopyDefaultProfileIfNotExists() (bool, error) {
if err := os.MkdirAll(DefaultConfigPathDir, 0600); err != nil {
@@ -114,14 +126,6 @@ func (s *ServiceManager) CopyDefaultProfileIfNotExists() (bool, error) {
log.Warnf("failed to set permissions for default profile: %v", err)
}
- if err := s.SetActiveProfileState(&ActiveProfileState{
- Name: "default",
- Username: "",
- }); err != nil {
- log.Errorf("failed to set active profile state: %v", err)
- return false, fmt.Errorf("failed to set active profile state: %w", err)
- }
-
return true, nil
}
@@ -240,7 +244,7 @@ func (s *ServiceManager) DefaultProfilePath() string {
}
func (s *ServiceManager) AddProfile(profileName, username string) error {
- configDir, err := getConfigDirForUser(username)
+ configDir, err := s.getConfigDir(username)
if err != nil {
return fmt.Errorf("failed to get config directory: %w", err)
}
@@ -270,7 +274,7 @@ func (s *ServiceManager) AddProfile(profileName, username string) error {
}
func (s *ServiceManager) RemoveProfile(profileName, username string) error {
- configDir, err := getConfigDirForUser(username)
+ configDir, err := s.getConfigDir(username)
if err != nil {
return fmt.Errorf("failed to get config directory: %w", err)
}
@@ -302,7 +306,7 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error {
}
func (s *ServiceManager) ListProfiles(username string) ([]Profile, error) {
- configDir, err := getConfigDirForUser(username)
+ configDir, err := s.getConfigDir(username)
if err != nil {
return nil, fmt.Errorf("failed to get config directory: %w", err)
}
@@ -361,7 +365,7 @@ func (s *ServiceManager) GetStatePath() string {
return defaultStatePath
}
- configDir, err := getConfigDirForUser(activeProf.Username)
+ configDir, err := s.getConfigDir(activeProf.Username)
if err != nil {
log.Warnf("failed to get config directory for user %s: %v", activeProf.Username, err)
return defaultStatePath
@@ -369,3 +373,12 @@ func (s *ServiceManager) GetStatePath() string {
return filepath.Join(configDir, activeProf.Name+".state.json")
}
+
+// getConfigDir returns the profiles directory, using profilesDir if set, otherwise getConfigDirForUser
+func (s *ServiceManager) getConfigDir(username string) (string, error) {
+ if s.profilesDir != "" {
+ return s.profilesDir, nil
+ }
+
+ return getConfigDirForUser(username)
+}
diff --git a/client/internal/updatemanager/doc.go b/client/internal/updatemanager/doc.go
new file mode 100644
index 000000000..54d1bdeab
--- /dev/null
+++ b/client/internal/updatemanager/doc.go
@@ -0,0 +1,35 @@
+// Package updatemanager provides automatic update management for the NetBird client.
+// It monitors for new versions, handles update triggers from management server directives,
+// and orchestrates the download and installation of client updates.
+//
+// # Overview
+//
+// The update manager operates as a background service that continuously monitors for
+// available updates and automatically initiates the update process when conditions are met.
+// It integrates with the installer package to perform the actual installation.
+//
+// # Update Flow
+//
+// The complete update process follows these steps:
+//
+// 1. Manager receives update directive via SetVersion() or detects new version
+// 2. Manager validates update should proceed (version comparison, rate limiting)
+// 3. Manager publishes "updating" event to status recorder
+// 4. Manager persists UpdateState to track update attempt
+// 5. Manager downloads installer file (.msi or .exe) to temporary directory
+// 6. Manager triggers installation via installer.RunInstallation()
+// 7. Installer package handles the actual installation process
+// 8. On next startup, CheckUpdateSuccess() verifies update completion
+// 9. Manager publishes success/failure event to status recorder
+// 10. Manager cleans up UpdateState
+//
+// # State Management
+//
+// Update state is persisted across restarts to track update attempts:
+//
+// - PreUpdateVersion: Version before update attempt
+// - TargetVersion: Version attempting to update to
+//
+// This enables verification of successful updates and appropriate user notification
+// after the client restarts with the new version.
+package updatemanager
diff --git a/client/internal/updatemanager/downloader/downloader.go b/client/internal/updatemanager/downloader/downloader.go
new file mode 100644
index 000000000..2ac36efed
--- /dev/null
+++ b/client/internal/updatemanager/downloader/downloader.go
@@ -0,0 +1,138 @@
+package downloader
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/version"
+)
+
+const (
+ userAgent = "NetBird agent installer/%s"
+ DefaultRetryDelay = 3 * time.Second
+)
+
+func DownloadToFile(ctx context.Context, retryDelay time.Duration, url, dstFile string) error {
+ log.Debugf("starting download from %s", url)
+
+ out, err := os.Create(dstFile)
+ if err != nil {
+ return fmt.Errorf("failed to create destination file %q: %w", dstFile, err)
+ }
+ defer func() {
+ if cerr := out.Close(); cerr != nil {
+ log.Warnf("error closing file %q: %v", dstFile, cerr)
+ }
+ }()
+
+ // First attempt
+ err = downloadToFileOnce(ctx, url, out)
+ if err == nil {
+ log.Infof("successfully downloaded file to %s", dstFile)
+ return nil
+ }
+
+ // If retryDelay is 0, don't retry
+ if retryDelay == 0 {
+ return err
+ }
+
+ log.Warnf("download failed, retrying after %v: %v", retryDelay, err)
+
+ // Sleep before retry
+ if sleepErr := sleepWithContext(ctx, retryDelay); sleepErr != nil {
+ return fmt.Errorf("download cancelled during retry delay: %w", sleepErr)
+ }
+
+ // Truncate file before retry
+ if err := out.Truncate(0); err != nil {
+ return fmt.Errorf("failed to truncate file on retry: %w", err)
+ }
+ if _, err := out.Seek(0, 0); err != nil {
+ return fmt.Errorf("failed to seek to beginning of file: %w", err)
+ }
+
+ // Second attempt
+ if err := downloadToFileOnce(ctx, url, out); err != nil {
+ return fmt.Errorf("download failed after retry: %w", err)
+ }
+
+ log.Infof("successfully downloaded file to %s", dstFile)
+ return nil
+}
+
+func DownloadToMemory(ctx context.Context, url string, limit int64) ([]byte, error) {
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create HTTP request: %w", err)
+ }
+
+ // Add User-Agent header
+ req.Header.Set("User-Agent", fmt.Sprintf(userAgent, version.NetbirdVersion()))
+
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("failed to perform HTTP request: %w", err)
+ }
+ defer func() {
+ if cerr := resp.Body.Close(); cerr != nil {
+ log.Warnf("error closing response body: %v", cerr)
+ }
+ }()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode)
+ }
+
+ data, err := io.ReadAll(io.LimitReader(resp.Body, limit))
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response body: %w", err)
+ }
+
+ return data, nil
+}
+
+func downloadToFileOnce(ctx context.Context, url string, out *os.File) error {
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+ if err != nil {
+ return fmt.Errorf("failed to create HTTP request: %w", err)
+ }
+
+ // Add User-Agent header
+ req.Header.Set("User-Agent", fmt.Sprintf(userAgent, version.NetbirdVersion()))
+
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ return fmt.Errorf("failed to perform HTTP request: %w", err)
+ }
+ defer func() {
+ if cerr := resp.Body.Close(); cerr != nil {
+ log.Warnf("error closing response body: %v", cerr)
+ }
+ }()
+
+ if resp.StatusCode != http.StatusOK {
+ return fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode)
+ }
+
+ if _, err := io.Copy(out, resp.Body); err != nil {
+ return fmt.Errorf("failed to write response body to file: %w", err)
+ }
+
+ return nil
+}
+
+func sleepWithContext(ctx context.Context, duration time.Duration) error {
+ select {
+ case <-time.After(duration):
+ return nil
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+}
diff --git a/client/internal/updatemanager/downloader/downloader_test.go b/client/internal/updatemanager/downloader/downloader_test.go
new file mode 100644
index 000000000..045db3a2d
--- /dev/null
+++ b/client/internal/updatemanager/downloader/downloader_test.go
@@ -0,0 +1,199 @@
+package downloader
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+const (
+ retryDelay = 100 * time.Millisecond
+)
+
+func TestDownloadToFile_Success(t *testing.T) {
+ // Create a test server that responds successfully
+ content := "test file content"
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte(content))
+ }))
+ defer server.Close()
+
+ // Create a temporary file for download
+ tempDir := t.TempDir()
+ dstFile := filepath.Join(tempDir, "downloaded.txt")
+
+ // Download the file
+ err := DownloadToFile(context.Background(), retryDelay, server.URL, dstFile)
+ if err != nil {
+ t.Fatalf("expected no error, got: %v", err)
+ }
+
+ // Verify the file content
+ data, err := os.ReadFile(dstFile)
+ if err != nil {
+ t.Fatalf("failed to read downloaded file: %v", err)
+ }
+
+ if string(data) != content {
+ t.Errorf("expected content %q, got %q", content, string(data))
+ }
+}
+
+func TestDownloadToFile_SuccessAfterRetry(t *testing.T) {
+ content := "test file content after retry"
+ var attemptCount atomic.Int32
+
+ // Create a test server that fails on first attempt, succeeds on second
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ attempt := attemptCount.Add(1)
+ if attempt == 1 {
+ w.WriteHeader(http.StatusInternalServerError)
+ _, _ = w.Write([]byte("error"))
+ return
+ }
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte(content))
+ }))
+ defer server.Close()
+
+ // Create a temporary file for download
+ tempDir := t.TempDir()
+ dstFile := filepath.Join(tempDir, "downloaded.txt")
+
+ // Download the file (should succeed after retry)
+ if err := DownloadToFile(context.Background(), 10*time.Millisecond, server.URL, dstFile); err != nil {
+ t.Fatalf("expected no error after retry, got: %v", err)
+ }
+
+ // Verify the file content
+ data, err := os.ReadFile(dstFile)
+ if err != nil {
+ t.Fatalf("failed to read downloaded file: %v", err)
+ }
+
+ if string(data) != content {
+ t.Errorf("expected content %q, got %q", content, string(data))
+ }
+
+ // Verify it took 2 attempts
+ if attemptCount.Load() != 2 {
+ t.Errorf("expected 2 attempts, got %d", attemptCount.Load())
+ }
+}
+
+func TestDownloadToFile_FailsAfterRetry(t *testing.T) {
+ var attemptCount atomic.Int32
+
+ // Create a test server that always fails
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ attemptCount.Add(1)
+ w.WriteHeader(http.StatusInternalServerError)
+ _, _ = w.Write([]byte("error"))
+ }))
+ defer server.Close()
+
+ // Create a temporary file for download
+ tempDir := t.TempDir()
+ dstFile := filepath.Join(tempDir, "downloaded.txt")
+
+ // Download the file (should fail after retry)
+ if err := DownloadToFile(context.Background(), 10*time.Millisecond, server.URL, dstFile); err == nil {
+ t.Fatal("expected error after retry, got nil")
+ }
+
+ // Verify it tried 2 times
+ if attemptCount.Load() != 2 {
+ t.Errorf("expected 2 attempts, got %d", attemptCount.Load())
+ }
+}
+
+func TestDownloadToFile_ContextCancellationDuringRetry(t *testing.T) {
+ var attemptCount atomic.Int32
+
+ // Create a test server that always fails
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ attemptCount.Add(1)
+ w.WriteHeader(http.StatusInternalServerError)
+ }))
+ defer server.Close()
+
+ // Create a temporary file for download
+ tempDir := t.TempDir()
+ dstFile := filepath.Join(tempDir, "downloaded.txt")
+
+ // Create a context that will be cancelled during retry delay
+ ctx, cancel := context.WithCancel(context.Background())
+
+ // Cancel after a short delay (during the retry sleep)
+ go func() {
+ time.Sleep(100 * time.Millisecond)
+ cancel()
+ }()
+
+ // Download the file (should fail due to context cancellation during retry)
+ err := DownloadToFile(ctx, 1*time.Second, server.URL, dstFile)
+ if err == nil {
+ t.Fatal("expected error due to context cancellation, got nil")
+ }
+
+ // Should have only made 1 attempt (cancelled during retry delay)
+ if attemptCount.Load() != 1 {
+ t.Errorf("expected 1 attempt, got %d", attemptCount.Load())
+ }
+}
+
+func TestDownloadToFile_InvalidURL(t *testing.T) {
+ tempDir := t.TempDir()
+ dstFile := filepath.Join(tempDir, "downloaded.txt")
+
+ err := DownloadToFile(context.Background(), retryDelay, "://invalid-url", dstFile)
+ if err == nil {
+ t.Fatal("expected error for invalid URL, got nil")
+ }
+}
+
+func TestDownloadToFile_InvalidDestination(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte("test"))
+ }))
+ defer server.Close()
+
+ // Use an invalid destination path
+ err := DownloadToFile(context.Background(), retryDelay, server.URL, "/invalid/path/that/does/not/exist/file.txt")
+ if err == nil {
+ t.Fatal("expected error for invalid destination, got nil")
+ }
+}
+
+func TestDownloadToFile_NoRetry(t *testing.T) {
+ var attemptCount atomic.Int32
+
+ // Create a test server that always fails
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ attemptCount.Add(1)
+ w.WriteHeader(http.StatusInternalServerError)
+ _, _ = w.Write([]byte("error"))
+ }))
+ defer server.Close()
+
+ // Create a temporary file for download
+ tempDir := t.TempDir()
+ dstFile := filepath.Join(tempDir, "downloaded.txt")
+
+ // Download the file with retryDelay = 0 (should not retry)
+ if err := DownloadToFile(context.Background(), 0, server.URL, dstFile); err == nil {
+ t.Fatal("expected error, got nil")
+ }
+
+ // Verify it only made 1 attempt (no retry)
+ if attemptCount.Load() != 1 {
+ t.Errorf("expected 1 attempt, got %d", attemptCount.Load())
+ }
+}
diff --git a/client/internal/updatemanager/installer/binary_nowindows.go b/client/internal/updatemanager/installer/binary_nowindows.go
new file mode 100644
index 000000000..19f3bef83
--- /dev/null
+++ b/client/internal/updatemanager/installer/binary_nowindows.go
@@ -0,0 +1,7 @@
+//go:build !windows
+
+package installer
+
+func UpdaterBinaryNameWithoutExtension() string {
+ return updaterBinary
+}
diff --git a/client/internal/updatemanager/installer/binary_windows.go b/client/internal/updatemanager/installer/binary_windows.go
new file mode 100644
index 000000000..4c66391c2
--- /dev/null
+++ b/client/internal/updatemanager/installer/binary_windows.go
@@ -0,0 +1,11 @@
+package installer
+
+import (
+ "path/filepath"
+ "strings"
+)
+
+func UpdaterBinaryNameWithoutExtension() string {
+ ext := filepath.Ext(updaterBinary)
+ return strings.TrimSuffix(updaterBinary, ext)
+}
diff --git a/client/internal/updatemanager/installer/doc.go b/client/internal/updatemanager/installer/doc.go
new file mode 100644
index 000000000..0a60454bb
--- /dev/null
+++ b/client/internal/updatemanager/installer/doc.go
@@ -0,0 +1,111 @@
+// Package installer provides functionality for managing NetBird application
+// updates and installations across Windows and macOS. It handles
+// the complete update lifecycle including artifact download, cryptographic verification,
+// installation execution, process management, and result reporting.
+//
+// # Architecture
+//
+// The installer package uses a two-process architecture to enable self-updates:
+//
+// 1. Service Process: The main NetBird daemon process that initiates updates
+// 2. Updater Process: A detached child process that performs the actual installation
+//
+// This separation is critical because:
+// - The service binary cannot update itself while running
+// - The installer (EXE/MSI/PKG) will terminate the service during installation
+// - The updater process survives service termination and restarts it after installation
+// - Results can be communicated back to the service after it restarts
+//
+// # Update Flow
+//
+// Service Process (RunInstallation):
+//
+// 1. Validates target version format (semver)
+// 2. Determines installer type (EXE, MSI, PKG, or Homebrew)
+// 3. Downloads installer file from GitHub releases (if applicable)
+// 4. Verifies installer signature using reposign package (cryptographic verification in service process before
+// launching updater)
+// 5. Copies service binary to tempDir as "updater" (or "updater.exe" on Windows)
+// 6. Launches updater process with detached mode:
+// - --temp-dir: Temporary directory path
+// - --service-dir: Service installation directory
+// - --installer-file: Path to downloaded installer (if applicable)
+// - --dry-run: Optional flag to test without actually installing
+// 7. Service process continues running (will be terminated by installer later)
+// 8. Service can watch for result.json using ResultHandler.Watch() to detect completion
+//
+// Updater Process (Setup):
+//
+// 1. Receives parameters from service via command-line arguments
+// 2. Runs installer with appropriate silent/quiet flags:
+// - Windows EXE: installer.exe /S
+// - Windows MSI: msiexec.exe /i installer.msi /quiet /qn /l*v msi.log
+// - macOS PKG: installer -pkg installer.pkg -target /
+// - macOS Homebrew: brew upgrade netbirdio/tap/netbird
+// 3. Installer terminates daemon and UI processes
+// 4. Installer replaces binaries with new version
+// 5. Updater waits for installer to complete
+// 6. Updater restarts daemon:
+// - Windows: netbird.exe service start
+// - macOS/Linux: netbird service start
+// 7. Updater restarts UI:
+// - Windows: Launches netbird-ui.exe as active console user using CreateProcessAsUser
+// - macOS: Uses launchctl asuser to launch NetBird.app for console user
+// - Linux: Not implemented (UI typically auto-starts)
+// 8. Updater writes result.json with success/error status
+// 9. Updater process exits
+//
+// # Result Communication
+//
+// The ResultHandler (result.go) manages communication between updater and service:
+//
+// Result Structure:
+//
+// type Result struct {
+// Success bool // true if installation succeeded
+// Error string // error message if Success is false
+// ExecutedAt time.Time // when installation completed
+// }
+//
+// Result files are automatically cleaned up after being read.
+//
+// # File Locations
+//
+// Temporary Directory (platform-specific):
+//
+// Windows:
+// - Path: %ProgramData%\Netbird\tmp-install
+// - Example: C:\ProgramData\Netbird\tmp-install
+//
+// macOS:
+// - Path: /var/lib/netbird/tmp-install
+// - Requires root permissions
+//
+// Files created during installation:
+//
+// tmp-install/
+// installer.log
+// updater[.exe] # Copy of service binary
+// netbird_installer_*.[exe|msi|pkg] # Downloaded installer
+// result.json # Installation result
+// msi.log # MSI verbose log (Windows MSI only)
+//
+// # API Reference
+//
+// # Cleanup
+//
+// CleanUpInstallerFiles() removes temporary files after successful installation:
+// - Downloaded installer files (*.exe, *.msi, *.pkg)
+// - Updater binary copy
+// - Does NOT remove result.json (cleaned by ResultHandler after read)
+// - Does NOT remove msi.log (kept for debugging)
+//
+// # Dry-Run Mode
+//
+// Dry-run mode allows testing the update process without actually installing:
+//
+// Enable via environment variable:
+//
+// export NB_AUTO_UPDATE_DRY_RUN=true
+// netbird service install-update 0.29.0
+package installer
diff --git a/client/internal/updatemanager/installer/installer.go b/client/internal/updatemanager/installer/installer.go
new file mode 100644
index 000000000..caf5873f8
--- /dev/null
+++ b/client/internal/updatemanager/installer/installer.go
@@ -0,0 +1,50 @@
+//go:build !windows && !darwin
+
+package installer
+
+import (
+ "context"
+ "fmt"
+)
+
+const (
+ updaterBinary = "updater"
+)
+
+type Installer struct {
+ tempDir string
+}
+
+// New used by the service
+func New() *Installer {
+ return &Installer{}
+}
+
+// NewWithDir used by the updater process, get the tempDir from the service via cmd line
+func NewWithDir(tempDir string) *Installer {
+ return &Installer{
+ tempDir: tempDir,
+ }
+}
+
+func (u *Installer) TempDir() string {
+ return ""
+}
+
+func (c *Installer) LogFiles() []string {
+ return []string{}
+}
+
+func (u *Installer) CleanUpInstallerFiles() error {
+ return nil
+}
+
+func (u *Installer) RunInstallation(ctx context.Context, targetVersion string) error {
+ return fmt.Errorf("unsupported platform")
+}
+
+// Setup runs the installer with appropriate arguments and manages the daemon/UI state
+// This will be run by the updater process
+func (u *Installer) Setup(ctx context.Context, dryRun bool, targetVersion string, daemonFolder string) (resultErr error) {
+ return fmt.Errorf("unsupported platform")
+}
diff --git a/client/internal/updatemanager/installer/installer_common.go b/client/internal/updatemanager/installer/installer_common.go
new file mode 100644
index 000000000..03378d55f
--- /dev/null
+++ b/client/internal/updatemanager/installer/installer_common.go
@@ -0,0 +1,293 @@
+//go:build windows || darwin
+
+package installer
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "os"
+ "os/exec"
+ "path"
+ "path/filepath"
+ "strings"
+
+ "github.com/hashicorp/go-multierror"
+ goversion "github.com/hashicorp/go-version"
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/client/internal/updatemanager/downloader"
+ "github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
+)
+
+type Installer struct {
+ tempDir string
+}
+
+// New used by the service
+func New() *Installer {
+ return &Installer{
+ tempDir: defaultTempDir,
+ }
+}
+
+// NewWithDir used by the updater process, get the tempDir from the service via cmd line
+func NewWithDir(tempDir string) *Installer {
+ return &Installer{
+ tempDir: tempDir,
+ }
+}
+
+// RunInstallation starts the updater process to run the installation
+// This will be run by the original service process
+func (u *Installer) RunInstallation(ctx context.Context, targetVersion string) (err error) {
+ resultHandler := NewResultHandler(u.tempDir)
+
+ defer func() {
+ if err != nil {
+ if writeErr := resultHandler.WriteErr(err); writeErr != nil {
+ log.Errorf("failed to write error result: %v", writeErr)
+ }
+ }
+ }()
+
+ if err := validateTargetVersion(targetVersion); err != nil {
+ return err
+ }
+
+ if err := u.mkTempDir(); err != nil {
+ return err
+ }
+
+ var installerFile string
+ // Download files only when not using any third-party store
+ if installerType := TypeOfInstaller(ctx); installerType.Downloadable() {
+ log.Infof("download installer")
+ var err error
+ installerFile, err = u.downloadInstaller(ctx, installerType, targetVersion)
+ if err != nil {
+ log.Errorf("failed to download installer: %v", err)
+ return err
+ }
+
+ artifactVerify, err := reposign.NewArtifactVerify(DefaultSigningKeysBaseURL)
+ if err != nil {
+ log.Errorf("failed to create artifact verify: %v", err)
+ return err
+ }
+
+ if err := artifactVerify.Verify(ctx, targetVersion, installerFile); err != nil {
+ log.Errorf("artifact verification error: %v", err)
+ return err
+ }
+ }
+
+ log.Infof("running installer")
+ updaterPath, err := u.copyUpdater()
+ if err != nil {
+ return err
+ }
+
+ // the directory where the service has been installed
+ workspace, err := getServiceDir()
+ if err != nil {
+ return err
+ }
+
+ args := []string{
+ "--temp-dir", u.tempDir,
+ "--service-dir", workspace,
+ }
+
+ if isDryRunEnabled() {
+ args = append(args, "--dry-run=true")
+ }
+
+ if installerFile != "" {
+ args = append(args, "--installer-file", installerFile)
+ }
+
+ updateCmd := exec.Command(updaterPath, args...)
+ log.Infof("starting updater process: %s", updateCmd.String())
+
+ // Configure the updater to run in a separate session/process group
+ // so it survives the parent daemon being stopped
+ setUpdaterProcAttr(updateCmd)
+
+ // Start the updater process asynchronously
+ if err := updateCmd.Start(); err != nil {
+ return err
+ }
+
+ pid := updateCmd.Process.Pid
+ log.Infof("updater started with PID %d", pid)
+
+ // Release the process so the OS can fully detach it
+ if err := updateCmd.Process.Release(); err != nil {
+ log.Warnf("failed to release updater process: %v", err)
+ }
+
+ return nil
+}
+
+// CleanUpInstallerFiles
+// - the installer file (pkg, exe, msi)
+// - the self-copied updater binary
+func (u *Installer) CleanUpInstallerFiles() error {
+ // Check if tempDir exists
+ info, err := os.Stat(u.tempDir)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil
+ }
+ return err
+ }
+
+ if !info.IsDir() {
+ return nil
+ }
+
+ var merr *multierror.Error
+
+ if err := os.Remove(filepath.Join(u.tempDir, updaterBinary)); err != nil && !os.IsNotExist(err) {
+ merr = multierror.Append(merr, fmt.Errorf("failed to remove updater binary: %w", err))
+ }
+
+ entries, err := os.ReadDir(u.tempDir)
+ if err != nil {
+ return err
+ }
+
+ for _, entry := range entries {
+ if entry.IsDir() {
+ continue
+ }
+
+ name := entry.Name()
+ for _, ext := range binaryExtensions {
+ if strings.HasSuffix(strings.ToLower(name), strings.ToLower(ext)) {
+ if err := os.Remove(filepath.Join(u.tempDir, name)); err != nil {
+ merr = multierror.Append(merr, fmt.Errorf("failed to remove %s: %w", name, err))
+ }
+ break
+ }
+ }
+ }
+
+ return merr.ErrorOrNil()
+}
+
+func (u *Installer) downloadInstaller(ctx context.Context, installerType Type, targetVersion string) (string, error) {
+ fileURL := urlWithVersionArch(installerType, targetVersion)
+
+ // Clean up temp directory on error
+ var success bool
+ defer func() {
+ if !success {
+ if err := os.RemoveAll(u.tempDir); err != nil {
+ log.Errorf("error cleaning up temporary directory: %v", err)
+ }
+ }
+ }()
+
+ fileName := path.Base(fileURL)
+ if fileName == "." || fileName == "/" || fileName == "" {
+ return "", fmt.Errorf("invalid file URL: %s", fileURL)
+ }
+
+ outputFilePath := filepath.Join(u.tempDir, fileName)
+ if err := downloader.DownloadToFile(ctx, downloader.DefaultRetryDelay, fileURL, outputFilePath); err != nil {
+ return "", err
+ }
+
+ success = true
+ return outputFilePath, nil
+}
+
+func (u *Installer) TempDir() string {
+ return u.tempDir
+}
+
+func (u *Installer) mkTempDir() error {
+ if err := os.MkdirAll(u.tempDir, 0o755); err != nil {
+ log.Debugf("failed to create tempdir: %s", u.tempDir)
+ return err
+ }
+ return nil
+}
+
+func (u *Installer) copyUpdater() (string, error) {
+ src, err := getServiceBinary()
+ if err != nil {
+ return "", fmt.Errorf("failed to get updater binary: %w", err)
+ }
+
+ dst := filepath.Join(u.tempDir, updaterBinary)
+ if err := copyFile(src, dst); err != nil {
+ return "", fmt.Errorf("failed to copy updater binary: %w", err)
+ }
+
+ if err := os.Chmod(dst, 0o755); err != nil {
+ return "", fmt.Errorf("failed to set permissions: %w", err)
+ }
+
+ return dst, nil
+}
+
+func validateTargetVersion(targetVersion string) error {
+ if targetVersion == "" {
+ return fmt.Errorf("target version cannot be empty")
+ }
+
+ _, err := goversion.NewVersion(targetVersion)
+ if err != nil {
+ return fmt.Errorf("invalid target version %q: %w", targetVersion, err)
+ }
+
+ return nil
+}
+
+func copyFile(src, dst string) error {
+ log.Infof("copying %s to %s", src, dst)
+ in, err := os.Open(src)
+ if err != nil {
+ return fmt.Errorf("open source: %w", err)
+ }
+ defer func() {
+ if err := in.Close(); err != nil {
+ log.Warnf("failed to close source file: %v", err)
+ }
+ }()
+
+ out, err := os.Create(dst)
+ if err != nil {
+ return fmt.Errorf("create destination: %w", err)
+ }
+ defer func() {
+ if err := out.Close(); err != nil {
+ log.Warnf("failed to close destination file: %v", err)
+ }
+ }()
+
+ if _, err := io.Copy(out, in); err != nil {
+ return fmt.Errorf("copy: %w", err)
+ }
+
+ return nil
+}
+
+func getServiceDir() (string, error) {
+ exePath, err := os.Executable()
+ if err != nil {
+ return "", err
+ }
+ return filepath.Dir(exePath), nil
+}
+
+func getServiceBinary() (string, error) {
+ return os.Executable()
+}
+
+func isDryRunEnabled() bool {
+ return strings.EqualFold(strings.TrimSpace(os.Getenv("NB_AUTO_UPDATE_DRY_RUN")), "true")
+}
diff --git a/client/internal/updatemanager/installer/installer_log_darwin.go b/client/internal/updatemanager/installer/installer_log_darwin.go
new file mode 100644
index 000000000..50dd5d197
--- /dev/null
+++ b/client/internal/updatemanager/installer/installer_log_darwin.go
@@ -0,0 +1,11 @@
+package installer
+
+import (
+ "path/filepath"
+)
+
+// LogFiles returns the updater log paths to collect on macOS: only the
+// updater's own log inside the temp install directory.
+func (u *Installer) LogFiles() []string {
+	return []string{
+		filepath.Join(u.tempDir, LogFile),
+	}
+}
diff --git a/client/internal/updatemanager/installer/installer_log_windows.go b/client/internal/updatemanager/installer/installer_log_windows.go
new file mode 100644
index 000000000..96e4cfd1f
--- /dev/null
+++ b/client/internal/updatemanager/installer/installer_log_windows.go
@@ -0,0 +1,12 @@
+package installer
+
+import (
+ "path/filepath"
+)
+
+// LogFiles returns the updater log paths to collect on Windows: the msiexec
+// verbose log plus the updater's own log, both in the temp install directory.
+func (u *Installer) LogFiles() []string {
+	return []string{
+		filepath.Join(u.tempDir, msiLogFile),
+		filepath.Join(u.tempDir, LogFile),
+	}
+}
diff --git a/client/internal/updatemanager/installer/installer_run_darwin.go b/client/internal/updatemanager/installer/installer_run_darwin.go
new file mode 100644
index 000000000..248a404aa
--- /dev/null
+++ b/client/internal/updatemanager/installer/installer_run_darwin.go
@@ -0,0 +1,238 @@
+package installer
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "os/exec"
+ "os/user"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "syscall"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+)
+
+const (
+ daemonName = "netbird"
+ updaterBinary = "updater"
+ uiBinary = "/Applications/NetBird.app"
+
+ defaultTempDir = "/var/lib/netbird/tmp-install"
+
+ pkgDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_%version_darwin_%arch.pkg"
+)
+
+var (
+ binaryExtensions = []string{"pkg"}
+)
+
+// Setup runs the installer with appropriate arguments and manages the daemon/UI state
+// This will be run by the updater process.
+//
+// The deferred block always persists the outcome (success or resultErr) via
+// ResultHandler first, then — unless dryRun — restarts the daemon from
+// daemonFolder and relaunches the UI in the console user's session.
+func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) {
+	resultHandler := NewResultHandler(u.tempDir)
+
+	// Always ensure daemon and UI are restarted after setup
+	defer func() {
+		log.Infof("write out result")
+		var err error
+		if resultErr == nil {
+			err = resultHandler.WriteSuccess()
+		} else {
+			err = resultHandler.WriteErr(resultErr)
+		}
+		if err != nil {
+			log.Errorf("failed to write update result: %v", err)
+		}
+
+		// skip service restart if dry-run mode is enabled
+		if dryRun {
+			return
+		}
+
+		log.Infof("starting daemon back")
+		if err := u.startDaemon(daemonFolder); err != nil {
+			log.Errorf("failed to start daemon: %v", err)
+		}
+
+		log.Infof("starting UI back")
+		if err := u.startUIAsUser(); err != nil {
+			log.Errorf("failed to start UI: %v", err)
+		}
+
+	}()
+
+	if dryRun {
+		// NOTE(review): the sleep presumably simulates installer run time so
+		// the result watcher can be exercised — confirm before removing.
+		time.Sleep(7 * time.Second)
+		log.Infof("dry-run mode enabled, skipping actual installation")
+		resultErr = fmt.Errorf("dry-run mode enabled")
+		return
+	}
+
+	// Dispatch on how the client was originally installed (pkg vs Homebrew);
+	// these are the only two types the darwin TypeOfInstaller can return.
+	switch TypeOfInstaller(ctx) {
+	case TypePKG:
+		resultErr = u.installPkgFile(ctx, installerFile)
+	case TypeHomebrew:
+		resultErr = u.updateHomeBrew(ctx)
+	}
+
+	return resultErr
+}
+
+// startDaemon starts the netbird service by invoking
+// "<daemonFolder>/netbird service start" with a 15-second timeout, returning
+// the command error (combined output is logged) on failure.
+func (u *Installer) startDaemon(daemonFolder string) error {
+	log.Infof("starting netbird service")
+	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
+	defer cancel()
+
+	cmd := exec.CommandContext(ctx, filepath.Join(daemonFolder, daemonName), "service", "start")
+	if output, err := cmd.CombinedOutput(); err != nil {
+		log.Warnf("failed to start netbird service: %v, output: %s", err, string(output))
+		return err
+	}
+	log.Infof("netbird service started successfully")
+	return nil
+}
+
+// startUIAsUser launches the NetBird.app UI in the currently logged-in
+// console user's session. The console user is resolved by stat'ing the owner
+// of /dev/console; the app is then opened via `launchctl asuser <uid> open -a`
+// so it gets proper GUI access, and the child process is released so it
+// outlives this (short-lived) updater process.
+func (u *Installer) startUIAsUser() error {
+	log.Infof("starting netbird-ui: %s", uiBinary)
+
+	// Get the current console user
+	cmd := exec.Command("stat", "-f", "%Su", "/dev/console")
+	output, err := cmd.Output()
+	if err != nil {
+		return fmt.Errorf("failed to get console user: %w", err)
+	}
+
+	username := strings.TrimSpace(string(output))
+	// /dev/console is owned by root when no one is logged in at the GUI.
+	if username == "" || username == "root" {
+		return fmt.Errorf("no active user session found")
+	}
+
+	log.Infof("starting UI for user: %s", username)
+
+	// Get user's UID
+	userInfo, err := user.Lookup(username)
+	if err != nil {
+		return fmt.Errorf("failed to lookup user %s: %w", username, err)
+	}
+
+	// Start the UI process as the console user using launchctl
+	// This ensures the app runs in the user's context with proper GUI access
+	launchCmd := exec.Command("launchctl", "asuser", userInfo.Uid, "open", "-a", uiBinary)
+	log.Infof("launchCmd: %s", launchCmd.String())
+	// Set the user's home directory for proper macOS app behavior
+	launchCmd.Env = append(os.Environ(), "HOME="+userInfo.HomeDir)
+	log.Infof("set HOME environment variable: %s", userInfo.HomeDir)
+
+	if err := launchCmd.Start(); err != nil {
+		return fmt.Errorf("failed to start UI process: %w", err)
+	}
+
+	// Release the process so it can run independently
+	if err := launchCmd.Process.Release(); err != nil {
+		log.Warnf("failed to release UI process: %v", err)
+	}
+
+	log.Infof("netbird-ui started successfully for user %s", username)
+	return nil
+}
+
+// installPkgFile installs the downloaded .pkg at path onto the root volume
+// using the system `installer` tool, killing any running UI first so the
+// package's postinstall "open $APP" starts the new version.
+func (u *Installer) installPkgFile(ctx context.Context, path string) error {
+	log.Infof("installing pkg file: %s", path)
+
+	// Kill any existing UI processes before installation
+	// This ensures the postinstall script's "open $APP" will start the new version
+	u.killUI()
+
+	volume := "/"
+
+	cmd := exec.CommandContext(ctx, "installer", "-pkg", path, "-target", volume)
+	// Start/Wait (rather than Run) so the installer PID can be logged.
+	if err := cmd.Start(); err != nil {
+		return fmt.Errorf("error running pkg file: %w", err)
+	}
+	log.Infof("installer started with PID %d", cmd.Process.Pid)
+	if err := cmd.Wait(); err != nil {
+		return fmt.Errorf("error running pkg file: %w", err)
+	}
+	log.Infof("pkg file installed successfully")
+	return nil
+}
+
+// updateHomeBrew upgrades a Homebrew-installed NetBird (and netbird-ui if
+// present). Because brew refuses to run as root, the user who installed the
+// tap is identified via the owner of the tap directory (checking the Apple
+// Silicon path first, then the Intel path), and brew is run as that user via
+// sudo with their HOME set. Homebrew always upgrades to latest — a specific
+// target version cannot be pinned.
+func (u *Installer) updateHomeBrew(ctx context.Context) error {
+	log.Infof("updating homebrew")
+
+	// Kill any existing UI processes before upgrade
+	// This ensures the new version will be started after upgrade
+	u.killUI()
+
+	// Homebrew must be run as a non-root user
+	// To find out which user installed NetBird using HomeBrew we can check the owner of our brew tap directory
+	// Check both Apple Silicon and Intel Mac paths
+	brewTapPath := "/opt/homebrew/Library/Taps/netbirdio/homebrew-tap/"
+	brewBinPath := "/opt/homebrew/bin/brew"
+	if _, err := os.Stat(brewTapPath); os.IsNotExist(err) {
+		// Try Intel Mac path
+		brewTapPath = "/usr/local/Homebrew/Library/Taps/netbirdio/homebrew-tap/"
+		brewBinPath = "/usr/local/bin/brew"
+	}
+
+	fileInfo, err := os.Stat(brewTapPath)
+	if err != nil {
+		return fmt.Errorf("error getting homebrew installation path info: %w", err)
+	}
+
+	// The raw syscall view of the stat is needed to read the owning UID.
+	fileSysInfo, ok := fileInfo.Sys().(*syscall.Stat_t)
+	if !ok {
+		return fmt.Errorf("error checking file owner, sysInfo type is %T not *syscall.Stat_t", fileInfo.Sys())
+	}
+
+	// Get username from UID
+	brewUser, err := user.LookupId(fmt.Sprintf("%d", fileSysInfo.Uid))
+	if err != nil {
+		return fmt.Errorf("error looking up brew installer user: %w", err)
+	}
+	userName := brewUser.Username
+	// Get user HOME, required for brew to run correctly
+	// https://github.com/Homebrew/brew/issues/15833
+	homeDir := brewUser.HomeDir
+
+	// Check if netbird-ui is installed (must run as the brew user, not root)
+	checkUICmd := exec.CommandContext(ctx, "sudo", "-u", userName, brewBinPath, "list", "--formula", "netbirdio/tap/netbird-ui")
+	checkUICmd.Env = append(os.Environ(), "HOME="+homeDir)
+	uiInstalled := checkUICmd.Run() == nil
+
+	// Homebrew does not support installing specific versions
+	// Thus it will always update to latest and ignore targetVersion
+	upgradeArgs := []string{"-u", userName, brewBinPath, "upgrade", "netbirdio/tap/netbird"}
+	if uiInstalled {
+		upgradeArgs = append(upgradeArgs, "netbirdio/tap/netbird-ui")
+	}
+
+	cmd := exec.CommandContext(ctx, "sudo", upgradeArgs...)
+	cmd.Env = append(os.Environ(), "HOME="+homeDir)
+
+	if output, err := cmd.CombinedOutput(); err != nil {
+		return fmt.Errorf("error running brew upgrade: %w, output: %s", err, string(output))
+	}
+
+	log.Infof("homebrew updated successfully")
+	return nil
+}
+
+// killUI best-effort terminates any running netbird-ui processes via
+// `pkill -x`. Failures are only logged: pkill exits 1 when nothing matched,
+// which is the normal "UI wasn't running" case.
+func (u *Installer) killUI() {
+	log.Infof("killing existing netbird-ui processes")
+	cmd := exec.Command("pkill", "-x", "netbird-ui")
+	if output, err := cmd.CombinedOutput(); err != nil {
+		// pkill returns exit code 1 if no processes matched, which is fine
+		log.Debugf("pkill netbird-ui result: %v, output: %s", err, string(output))
+	} else {
+		log.Infof("netbird-ui processes killed")
+	}
+}
+
+// urlWithVersionArch builds the pkg download URL by substituting the %version
+// and %arch placeholders. The installer type is ignored: macOS has a single
+// downloadable artifact format (.pkg).
+func urlWithVersionArch(_ Type, version string) string {
+	replacer := strings.NewReplacer("%version", version, "%arch", runtime.GOARCH)
+	return replacer.Replace(pkgDownloadURL)
+}
diff --git a/client/internal/updatemanager/installer/installer_run_windows.go b/client/internal/updatemanager/installer/installer_run_windows.go
new file mode 100644
index 000000000..70c7e32cf
--- /dev/null
+++ b/client/internal/updatemanager/installer/installer_run_windows.go
@@ -0,0 +1,213 @@
+package installer
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "time"
+ "unsafe"
+
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/sys/windows"
+)
+
+const (
+ daemonName = "netbird.exe"
+ uiName = "netbird-ui.exe"
+ updaterBinary = "updater.exe"
+
+ msiLogFile = "msi.log"
+
+ msiDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.msi"
+ exeDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.exe"
+)
+
+var (
+ defaultTempDir = filepath.Join(os.Getenv("ProgramData"), "Netbird", "tmp-install")
+
+ // for the cleanup
+ binaryExtensions = []string{"msi", "exe"}
+)
+
+// Setup runs the installer with appropriate arguments and manages the daemon/UI state
+// This will be run by the updater process.
+//
+// The deferred block persists the outcome via ResultHandler FIRST and only
+// then restarts the daemon and UI — matching the darwin Setup — so the
+// relaunched processes can immediately observe the installation result
+// instead of racing against the write.
+func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) {
+	resultHandler := NewResultHandler(u.tempDir)
+
+	// Always ensure daemon and UI are restarted after setup
+	defer func() {
+		log.Infof("write out result")
+		var err error
+		if resultErr == nil {
+			err = resultHandler.WriteSuccess()
+		} else {
+			err = resultHandler.WriteErr(resultErr)
+		}
+		if err != nil {
+			log.Errorf("failed to write update result: %v", err)
+		}
+
+		log.Infof("starting daemon back")
+		if err := u.startDaemon(daemonFolder); err != nil {
+			log.Errorf("failed to start daemon: %v", err)
+		}
+
+		log.Infof("starting UI back")
+		if err := u.startUIAsUser(daemonFolder); err != nil {
+			log.Errorf("failed to start UI: %v", err)
+		}
+	}()
+
+	if dryRun {
+		log.Infof("dry-run mode enabled, skipping actual installation")
+		resultErr = fmt.Errorf("dry-run mode enabled")
+		return
+	}
+
+	// Choose how to run the artifact based on its file extension.
+	installerType, err := typeByFileExtension(installerFile)
+	if err != nil {
+		log.Debugf("%v", err)
+		resultErr = err
+		return
+	}
+
+	var cmd *exec.Cmd
+	switch installerType {
+	case TypeExe:
+		// NSIS-style silent install.
+		log.Infof("run exe installer: %s", installerFile)
+		cmd = exec.CommandContext(ctx, installerFile, "/S")
+	default:
+		// MSI: quiet install with a verbose log next to the installer file.
+		installerDir := filepath.Dir(installerFile)
+		logPath := filepath.Join(installerDir, msiLogFile)
+		log.Infof("run msi installer: %s", installerFile)
+		cmd = exec.CommandContext(ctx, "msiexec.exe", "/i", filepath.Base(installerFile), "/quiet", "/qn", "/l*v", logPath)
+	}
+
+	// Run from the installer's directory (msiexec is given a bare file name).
+	cmd.Dir = filepath.Dir(installerFile)
+
+	if resultErr = cmd.Start(); resultErr != nil {
+		log.Errorf("error starting installer: %v", resultErr)
+		return
+	}
+
+	log.Infof("installer started with PID %d", cmd.Process.Pid)
+	if resultErr = cmd.Wait(); resultErr != nil {
+		log.Errorf("installer process finished with error: %v", resultErr)
+		return
+	}
+
+	return nil
+}
+
+// startDaemon starts the netbird service by invoking
+// "<daemonFolder>/netbird.exe service start" with a 15-second timeout.
+// NOTE(review): the failure here is logged at Debug while the darwin
+// counterpart uses Warn — consider aligning the levels.
+func (u *Installer) startDaemon(daemonFolder string) error {
+	log.Infof("starting netbird service")
+	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
+	defer cancel()
+
+	cmd := exec.CommandContext(ctx, filepath.Join(daemonFolder, daemonName), "service", "start")
+	if output, err := cmd.CombinedOutput(); err != nil {
+		log.Debugf("failed to start netbird service: %v, output: %s", err, string(output))
+		return err
+	}
+	log.Infof("netbird service started successfully")
+	return nil
+}
+
+// startUIAsUser launches netbird-ui.exe in the interactive desktop session of
+// the currently logged-in user. Running from a service context, it:
+//  1. finds the active console session,
+//  2. obtains that session's user token and duplicates it to a primary token,
+//  3. spawns the UI with CreateProcessAsUser on the "winsta0\default" desktop,
+//     detached from this process so it survives the updater exiting.
+// Returns an error if no user is logged on or any Win32 call fails.
+func (u *Installer) startUIAsUser(daemonFolder string) error {
+	uiPath := filepath.Join(daemonFolder, uiName)
+	log.Infof("starting netbird-ui: %s", uiPath)
+
+	// Get the active console session ID
+	sessionID := windows.WTSGetActiveConsoleSessionId()
+	// 0xFFFFFFFF is the documented "no session attached" sentinel.
+	if sessionID == 0xFFFFFFFF {
+		return fmt.Errorf("no active user session found")
+	}
+
+	// Get the user token for that session
+	var userToken windows.Token
+	err := windows.WTSQueryUserToken(sessionID, &userToken)
+	if err != nil {
+		return fmt.Errorf("failed to query user token: %w", err)
+	}
+	defer func() {
+		if err := userToken.Close(); err != nil {
+			log.Warnf("failed to close user token: %v", err)
+		}
+	}()
+
+	// Duplicate the token to a primary token
+	// (CreateProcessAsUser requires a primary token, not an impersonation one)
+	var primaryToken windows.Token
+	err = windows.DuplicateTokenEx(
+		userToken,
+		windows.MAXIMUM_ALLOWED,
+		nil,
+		windows.SecurityImpersonation,
+		windows.TokenPrimary,
+		&primaryToken,
+	)
+	if err != nil {
+		return fmt.Errorf("failed to duplicate token: %w", err)
+	}
+	defer func() {
+		if err := primaryToken.Close(); err != nil {
+			log.Warnf("failed to close token: %v", err)
+		}
+	}()
+
+	// Prepare startup info; target the interactive desktop explicitly.
+	var si windows.StartupInfo
+	si.Cb = uint32(unsafe.Sizeof(si))
+	si.Desktop = windows.StringToUTF16Ptr("winsta0\\default")
+
+	var pi windows.ProcessInformation
+
+	// Quote the path in case daemonFolder contains spaces.
+	cmdLine, err := windows.UTF16PtrFromString(fmt.Sprintf("\"%s\"", uiPath))
+	if err != nil {
+		return fmt.Errorf("failed to convert path to UTF16: %w", err)
+	}
+
+	creationFlags := uint32(0x00000200 | 0x00000008 | 0x00000400) // CREATE_NEW_PROCESS_GROUP | DETACHED_PROCESS | CREATE_UNICODE_ENVIRONMENT
+
+	err = windows.CreateProcessAsUser(
+		primaryToken,
+		nil,
+		cmdLine,
+		nil,
+		nil,
+		false,
+		creationFlags,
+		nil,
+		nil,
+		&si,
+		&pi,
+	)
+	if err != nil {
+		return fmt.Errorf("CreateProcessAsUser failed: %w", err)
+	}
+
+	// Close handles — the UI runs independently; we don't wait on it.
+	if err := windows.CloseHandle(pi.Process); err != nil {
+		log.Warnf("failed to close process handle: %v", err)
+	}
+	if err := windows.CloseHandle(pi.Thread); err != nil {
+		log.Warnf("failed to close thread handle: %v", err)
+	}
+
+	log.Infof("netbird-ui started successfully in session %d", sessionID)
+	return nil
+}
+
+func urlWithVersionArch(it Type, version string) string {
+ var url string
+ if it == TypeExe {
+ url = exeDownloadURL
+ } else {
+ url = msiDownloadURL
+ }
+ url = strings.ReplaceAll(url, "%version", version)
+ return strings.ReplaceAll(url, "%arch", runtime.GOARCH)
+}
diff --git a/client/internal/updatemanager/installer/log.go b/client/internal/updatemanager/installer/log.go
new file mode 100644
index 000000000..8b60dba28
--- /dev/null
+++ b/client/internal/updatemanager/installer/log.go
@@ -0,0 +1,5 @@
+package installer
+
+const (
+ LogFile = "installer.log"
+)
diff --git a/client/internal/updatemanager/installer/procattr_darwin.go b/client/internal/updatemanager/installer/procattr_darwin.go
new file mode 100644
index 000000000..56f2018bb
--- /dev/null
+++ b/client/internal/updatemanager/installer/procattr_darwin.go
@@ -0,0 +1,15 @@
+package installer
+
+import (
+ "os/exec"
+ "syscall"
+)
+
+// setUpdaterProcAttr configures the updater process to run in a new session,
+// making it independent of the parent daemon process. This ensures the updater
+// survives when the daemon is stopped during the pkg installation.
+// (Setsid detaches the child from the daemon's session and controlling
+// terminal, so signals to the daemon's process group do not reach it.)
+func setUpdaterProcAttr(cmd *exec.Cmd) {
+	cmd.SysProcAttr = &syscall.SysProcAttr{
+		Setsid: true,
+	}
+}
diff --git a/client/internal/updatemanager/installer/procattr_windows.go b/client/internal/updatemanager/installer/procattr_windows.go
new file mode 100644
index 000000000..29a8a2de0
--- /dev/null
+++ b/client/internal/updatemanager/installer/procattr_windows.go
@@ -0,0 +1,14 @@
+package installer
+
+import (
+ "os/exec"
+ "syscall"
+)
+
+// setUpdaterProcAttr configures the updater process to run detached from the parent,
+// making it independent of the parent daemon process.
+// (syscall on Windows does not define DETACHED_PROCESS, hence the raw value.)
+func setUpdaterProcAttr(cmd *exec.Cmd) {
+	cmd.SysProcAttr = &syscall.SysProcAttr{
+		CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP | 0x00000008, // 0x00000008 is DETACHED_PROCESS
+	}
+}
diff --git a/client/internal/updatemanager/installer/repourl_dev.go b/client/internal/updatemanager/installer/repourl_dev.go
new file mode 100644
index 000000000..088821ad3
--- /dev/null
+++ b/client/internal/updatemanager/installer/repourl_dev.go
@@ -0,0 +1,7 @@
+//go:build devartifactsign
+
+package installer
+
+const (
+ DefaultSigningKeysBaseURL = "http://192.168.0.10:9089/signrepo"
+)
diff --git a/client/internal/updatemanager/installer/repourl_prod.go b/client/internal/updatemanager/installer/repourl_prod.go
new file mode 100644
index 000000000..abddc62c1
--- /dev/null
+++ b/client/internal/updatemanager/installer/repourl_prod.go
@@ -0,0 +1,7 @@
+//go:build !devartifactsign
+
+package installer
+
+const (
+ DefaultSigningKeysBaseURL = "https://publickeys.netbird.io/artifact-signatures"
+)
diff --git a/client/internal/updatemanager/installer/result.go b/client/internal/updatemanager/installer/result.go
new file mode 100644
index 000000000..03d08d527
--- /dev/null
+++ b/client/internal/updatemanager/installer/result.go
@@ -0,0 +1,230 @@
+package installer
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "os"
+ "path/filepath"
+ "time"
+
+ "github.com/fsnotify/fsnotify"
+ log "github.com/sirupsen/logrus"
+)
+
+const (
+ resultFile = "result.json"
+)
+
+// Result is the JSON-serialized outcome of one installer run, written by the
+// updater process and read back by the daemon/UI.
+type Result struct {
+	Success    bool      // true when installation completed without error
+	Error      string    // failure reason; empty when Success is true
+	ExecutedAt time.Time // when the result was written
+}
+
+// ResultHandler handles reading and writing update results
+type ResultHandler struct {
+ resultFile string
+}
+
+// NewResultHandler creates a new communicator with the given directory path
+// The result file will be created as "result.json" in the specified directory
+func NewResultHandler(installerDir string) *ResultHandler {
+	// Create it if it doesn't exist
+	// do not care if already exists
+	_ = os.MkdirAll(installerDir, 0o700)
+
+	rh := &ResultHandler{
+		resultFile: filepath.Join(installerDir, resultFile),
+	}
+	return rh
+}
+
+// GetErrorResultReason returns the failure reason from a previously written
+// result, or "" when there is no result or the last result was a success.
+// tryReadResult deletes the result file after a successful read, so the
+// explicit cleanup below only matters for unreadable/corrupt files.
+func (rh *ResultHandler) GetErrorResultReason() string {
+	result, err := rh.tryReadResult()
+	if err == nil && !result.Success {
+		return result.Error
+	}
+
+	if err := rh.cleanup(); err != nil {
+		log.Warnf("failed to cleanup result file: %v", err)
+	}
+
+	return ""
+}
+
+// WriteSuccess atomically persists a successful result stamped with the
+// current time.
+func (rh *ResultHandler) WriteSuccess() error {
+	result := Result{
+		Success:    true,
+		ExecutedAt: time.Now(),
+	}
+	return rh.write(result)
+}
+
+// WriteErr atomically persists a failed result carrying errReason's message,
+// stamped with the current time.
+func (rh *ResultHandler) WriteErr(errReason error) error {
+	result := Result{
+		Success:    false,
+		Error:      errReason.Error(),
+		ExecutedAt: time.Now(),
+	}
+	return rh.write(result)
+}
+
+// Watch blocks until a result file appears (or ctx is cancelled) and returns
+// its contents. It first handles the case where the updater already finished,
+// then waits for the containing directory to exist before installing an
+// fsnotify watcher on it.
+func (rh *ResultHandler) Watch(ctx context.Context) (Result, error) {
+	log.Infof("start watching result: %s", rh.resultFile)
+
+	// Check if file already exists (updater finished before we started watching)
+	if result, err := rh.tryReadResult(); err == nil {
+		log.Infof("installer result: %v", result)
+		return result, nil
+	}
+
+	dir := filepath.Dir(rh.resultFile)
+
+	if err := rh.waitForDirectory(ctx, dir); err != nil {
+		return Result{}, err
+	}
+
+	return rh.watchForResultFile(ctx, dir)
+}
+
+// waitForDirectory polls every 300ms until dir exists and is a directory, or
+// ctx is cancelled. Note the first check happens only after the first tick,
+// so there is always at least a 300ms delay even if dir already exists.
+func (rh *ResultHandler) waitForDirectory(ctx context.Context, dir string) error {
+	ticker := time.NewTicker(300 * time.Millisecond)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-ticker.C:
+			if info, err := os.Stat(dir); err == nil && info.IsDir() {
+				return nil
+			}
+		}
+	}
+}
+
+// watchForResultFile installs an fsnotify watcher on dir and loops over its
+// events until handleWatchEvent reports the result file was created and read,
+// ctx is cancelled, or the watcher fails/closes. A re-read between watcher
+// setup and the event loop closes the race where the file appeared after the
+// initial check in Watch but before the watcher was active.
+func (rh *ResultHandler) watchForResultFile(ctx context.Context, dir string) (Result, error) {
+	watcher, err := fsnotify.NewWatcher()
+	if err != nil {
+		log.Error(err)
+		return Result{}, err
+	}
+
+	defer func() {
+		if err := watcher.Close(); err != nil {
+			log.Warnf("failed to close watcher: %v", err)
+		}
+	}()
+
+	if err := watcher.Add(dir); err != nil {
+		return Result{}, fmt.Errorf("failed to watch directory: %v", err)
+	}
+
+	// Check again after setting up watcher to avoid race condition
+	// (file could have been created between initial check and watcher setup)
+	if result, err := rh.tryReadResult(); err == nil {
+		log.Infof("installer result: %v", result)
+		return result, nil
+	}
+
+	for {
+		select {
+		case <-ctx.Done():
+			return Result{}, ctx.Err()
+		case event, ok := <-watcher.Events:
+			if !ok {
+				return Result{}, errors.New("watcher closed unexpectedly")
+			}
+
+			if result, done := rh.handleWatchEvent(event); done {
+				return result, nil
+			}
+		case err, ok := <-watcher.Errors:
+			if !ok {
+				return Result{}, errors.New("watcher closed unexpectedly")
+			}
+			return Result{}, fmt.Errorf("watcher error: %w", err)
+		}
+	}
+}
+
+// handleWatchEvent reacts to a single fsnotify event: only a Create of the
+// exact result-file path terminates the watch (the writer renames a temp file
+// into place, which surfaces as Create). Returns (result, true) when watching
+// should stop.
+//
+// NOTE(review): when the file exists but tryReadResult fails, this still
+// returns done=true with a zero Result, so the caller sees "no result" with a
+// nil error — confirm whether continuing to watch would be preferable.
+func (rh *ResultHandler) handleWatchEvent(event fsnotify.Event) (Result, bool) {
+	if event.Name != rh.resultFile {
+		return Result{}, false
+	}
+
+	if event.Has(fsnotify.Create) {
+		result, err := rh.tryReadResult()
+		if err != nil {
+			log.Debugf("error while reading result: %v", err)
+			return result, true
+		}
+		log.Infof("installer result: %v", result)
+		return result, true
+	}
+
+	return Result{}, false
+}
+
+// Write writes the update result to a file for the UI to read.
+// The write is atomic: marshal to JSON, write a sibling ".tmp" file, then
+// rename it over the final path so readers never observe a partial file.
+func (rh *ResultHandler) write(result Result) error {
+	log.Infof("write out installer result to: %s", rh.resultFile)
+	// Ensure directory exists
+	dir := filepath.Dir(rh.resultFile)
+	if err := os.MkdirAll(dir, 0o755); err != nil {
+		log.Errorf("failed to create directory %s: %v", dir, err)
+		return err
+	}
+
+	data, err := json.Marshal(result)
+	if err != nil {
+		return err
+	}
+
+	// Write to a temporary file first, then rename for atomic operation
+	tmpPath := rh.resultFile + ".tmp"
+	if err := os.WriteFile(tmpPath, data, 0o600); err != nil {
+		log.Errorf("failed to create temp file: %s", err)
+		return err
+	}
+
+	// Atomic rename
+	if err := os.Rename(tmpPath, rh.resultFile); err != nil {
+		// Best-effort removal of the orphaned temp file. Log the removal
+		// error itself (this previously logged the rename error by mistake).
+		if cleanupErr := os.Remove(tmpPath); cleanupErr != nil {
+			log.Warnf("Failed to remove temp result file: %v", cleanupErr)
+		}
+		return err
+	}
+
+	return nil
+}
+
+// cleanup deletes the result file; a missing file is not an error.
+func (rh *ResultHandler) cleanup() error {
+	err := os.Remove(rh.resultFile)
+	if err != nil && !os.IsNotExist(err) {
+		return err
+	}
+	// Logged even when the file did not exist.
+	log.Debugf("delete installer result file: %s", rh.resultFile)
+	return nil
+}
+
+// tryReadResult attempts to read and validate the result file.
+// On a successful read the file is consumed: cleanup() deletes it so the same
+// result is not reported twice. Returns the os.ReadFile error when the file
+// is absent/unreadable, or a wrapped error for malformed JSON.
+func (rh *ResultHandler) tryReadResult() (Result, error) {
+	data, err := os.ReadFile(rh.resultFile)
+	if err != nil {
+		return Result{}, err
+	}
+
+	var result Result
+	if err := json.Unmarshal(data, &result); err != nil {
+		return Result{}, fmt.Errorf("invalid result format: %w", err)
+	}
+
+	if err := rh.cleanup(); err != nil {
+		log.Warnf("failed to cleanup result file: %v", err)
+	}
+
+	return result, nil
+}
diff --git a/client/internal/updatemanager/installer/types.go b/client/internal/updatemanager/installer/types.go
new file mode 100644
index 000000000..656d84f88
--- /dev/null
+++ b/client/internal/updatemanager/installer/types.go
@@ -0,0 +1,14 @@
+package installer
+
+// Type identifies how the client was installed (e.g. pkg/Homebrew on macOS,
+// MSI/EXE on Windows); the platform files define the concrete values.
+type Type struct {
+	name         string // human-readable name, returned by String
+	downloadable bool   // whether an artifact for this type can be downloaded
+}
+
+// String returns the installer type's display name.
+func (t Type) String() string {
+	return t.name
+}
+
+// Downloadable reports whether this installation method supports downloading
+// a release artifact (false for package managers like Homebrew).
+func (t Type) Downloadable() bool {
+	return t.downloadable
+}
diff --git a/client/internal/updatemanager/installer/types_darwin.go b/client/internal/updatemanager/installer/types_darwin.go
new file mode 100644
index 000000000..95a0cb737
--- /dev/null
+++ b/client/internal/updatemanager/installer/types_darwin.go
@@ -0,0 +1,22 @@
+package installer
+
+import (
+ "context"
+ "os/exec"
+)
+
+var (
+ TypeHomebrew = Type{name: "Homebrew", downloadable: false}
+ TypePKG = Type{name: "pkg", downloadable: true}
+)
+
+// TypeOfInstaller reports how NetBird was installed on macOS by querying the
+// pkg receipt database: if the "io.netbird.client" receipt is missing
+// (pkgutil exits with code 1), the client was installed via Homebrew,
+// otherwise via the pkg installer.
+func TypeOfInstaller(ctx context.Context) Type {
+	cmd := exec.CommandContext(ctx, "pkgutil", "--pkg-info", "io.netbird.client")
+	if _, err := cmd.Output(); err != nil {
+		// Read the exit code from the *exec.ExitError instead of
+		// cmd.ProcessState: ProcessState is nil when the command could not be
+		// started at all, and dereferencing it would panic in that case.
+		if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
+			// Not installed using pkg file, thus installed using Homebrew
+			return TypeHomebrew
+		}
+	}
+	return TypePKG
+}
diff --git a/client/internal/updatemanager/installer/types_windows.go b/client/internal/updatemanager/installer/types_windows.go
new file mode 100644
index 000000000..d4e5d83bd
--- /dev/null
+++ b/client/internal/updatemanager/installer/types_windows.go
@@ -0,0 +1,51 @@
+package installer
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/sys/windows/registry"
+)
+
+const (
+ uninstallKeyPath64 = `SOFTWARE\WOW6432Node\Microsoft\Windows\CurrentVersion\Uninstall\Netbird`
+ uninstallKeyPath32 = `SOFTWARE\Microsoft\Windows\CurrentVersion\Uninstall\Netbird`
+)
+
+var (
+ TypeExe = Type{name: "EXE", downloadable: true}
+ TypeMSI = Type{name: "MSI", downloadable: true}
+)
+
+// TypeOfInstaller reports how NetBird was installed on Windows by probing the
+// Add/Remove Programs uninstall registry keys (64-bit WOW6432Node path first,
+// then 32-bit): presence of either key means the NSIS EXE installer was used,
+// otherwise an MSI installation is assumed.
+func TypeOfInstaller(_ context.Context) Type {
+	paths := []string{uninstallKeyPath64, uninstallKeyPath32}
+
+	for _, path := range paths {
+		k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE)
+		if err != nil {
+			continue
+		}
+
+		if err := k.Close(); err != nil {
+			log.Warnf("Error closing registry key: %v", err)
+		}
+		return TypeExe
+
+	}
+
+	log.Debug("No registry entry found for Netbird, assuming MSI installation")
+	return TypeMSI
+}
+
+// typeByFileExtension infers the installer type from filePath's extension
+// (case-insensitive): ".exe" maps to TypeExe, ".msi" to TypeMSI; anything
+// else yields an error.
+func typeByFileExtension(filePath string) (Type, error) {
+	lowered := strings.ToLower(filePath)
+	switch {
+	case strings.HasSuffix(lowered, ".exe"):
+		return TypeExe, nil
+	case strings.HasSuffix(lowered, ".msi"):
+		return TypeMSI, nil
+	default:
+		return Type{}, fmt.Errorf("unsupported installer type for file: %s", filePath)
+	}
+}
diff --git a/client/internal/updatemanager/manager.go b/client/internal/updatemanager/manager.go
new file mode 100644
index 000000000..eae11de56
--- /dev/null
+++ b/client/internal/updatemanager/manager.go
@@ -0,0 +1,374 @@
+//go:build windows || darwin
+
+package updatemanager
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "runtime"
+ "sync"
+ "time"
+
+ v "github.com/hashicorp/go-version"
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/client/internal/peer"
+ "github.com/netbirdio/netbird/client/internal/statemanager"
+ "github.com/netbirdio/netbird/client/internal/updatemanager/installer"
+ cProto "github.com/netbirdio/netbird/client/proto"
+ "github.com/netbirdio/netbird/version"
+)
+
+const (
+ latestVersion = "latest"
+ // this version will be ignored
+ developmentVersion = "development"
+)
+
+var errNoUpdateState = errors.New("no update state found")
+
+// UpdateState is persisted via the state manager across an auto-update so
+// that, after restart, the new daemon can tell whether the update reached the
+// intended version.
+type UpdateState struct {
+	PreUpdateVersion string // version running before the update was triggered
+	TargetVersion    string // version the update was asked to install
+}
+
+// Name returns the state-manager key under which this state is stored.
+func (u UpdateState) Name() string {
+	return "autoUpdate"
+}
+
+// Manager drives automatic client updates: it watches the latest released
+// version (via the update fetcher) and the version requested by Management,
+// and triggers the installer when the running daemon is older than the target.
+type Manager struct {
+	statusRecorder *peer.Status         // sink for user-visible system events
+	stateManager   *statemanager.Manager // persists UpdateState across the update
+
+	lastTrigger    time.Time     // time of the last triggered update (throttle)
+	mgmUpdateChan  chan struct{} // signaled when Management sets a new target version
+	updateChannel  chan struct{} // signaled when the fetcher sees a new latest version
+	currentVersion string
+	update         UpdateInterface // latest-version fetcher; nil after Stop
+	wg             sync.WaitGroup
+
+	cancel context.CancelFunc // non-nil once Start ran; guards double-start
+
+	expectedVersion       *v.Version // specific version requested by Management
+	updateToLatestVersion bool       // Management requested "latest" instead
+
+	// updateMutex protect update and expectedVersion fields
+	updateMutex sync.Mutex
+
+	// triggerUpdateFn indirection allows tests to stub the real installer run.
+	triggerUpdateFn func(context.Context, string) error
+}
+
+// NewManager constructs an update Manager. On macOS it refuses Homebrew
+// installations (not downloadable, so auto-update is unsupported there).
+func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
+	if runtime.GOOS == "darwin" {
+		isBrew := !installer.TypeOfInstaller(context.Background()).Downloadable()
+		if isBrew {
+			log.Warnf("auto-update disabled on Home Brew installation")
+			return nil, fmt.Errorf("auto-update not supported on Home Brew installation yet")
+		}
+	}
+	return newManager(statusRecorder, stateManager)
+}
+
+// newManager wires up the Manager without the platform gate in NewManager;
+// used directly by tests so they can inject a mock update fetcher.
+func newManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
+	manager := &Manager{
+		statusRecorder: statusRecorder,
+		stateManager:   stateManager,
+		// Buffered (size 1) so signalers never block; coalesces bursts.
+		mgmUpdateChan:  make(chan struct{}, 1),
+		updateChannel:  make(chan struct{}, 1),
+		currentVersion: version.NetbirdVersion(),
+		update:         version.NewUpdate("nb/client"),
+	}
+	manager.triggerUpdateFn = manager.triggerUpdate
+
+	stateManager.RegisterState(&UpdateState{})
+
+	return manager, nil
+}
+
+// CheckUpdateSuccess checks if the update was successful and send a notification.
+// It works without to start the update manager.
+//
+// It publishes an error event if the installer left a failure result behind,
+// then compares the persisted UpdateState's target version with the running
+// version and publishes a success event on match. The state is deleted in
+// either case so the notification fires only once.
+func (m *Manager) CheckUpdateSuccess(ctx context.Context) {
+	reason := m.lastResultErrReason()
+	if reason != "" {
+		m.statusRecorder.PublishEvent(
+			cProto.SystemEvent_ERROR,
+			cProto.SystemEvent_SYSTEM,
+			"Auto-update failed",
+			fmt.Sprintf("Auto-update failed: %s", reason),
+			nil,
+		)
+	}
+
+	updateState, err := m.loadAndDeleteUpdateState(ctx)
+	if err != nil {
+		if errors.Is(err, errNoUpdateState) {
+			return
+		}
+		log.Errorf("failed to load update state: %v", err)
+		return
+	}
+
+	log.Debugf("auto-update state loaded, %v", *updateState)
+
+	// A mismatch (update did not reach the target) is intentionally silent
+	// here; the failure path above already reported the installer error.
+	if updateState.TargetVersion == m.currentVersion {
+		m.statusRecorder.PublishEvent(
+			cProto.SystemEvent_INFO,
+			cProto.SystemEvent_SYSTEM,
+			"Auto-update completed",
+			fmt.Sprintf("Your NetBird Client was auto-updated to version %s", m.currentVersion),
+			nil,
+		)
+		return
+	}
+}
+
+// Start launches the version fetcher and the update loop goroutine. Calling
+// Start twice is rejected (m.cancel stays set even after Stop, so the Manager
+// is effectively single-use).
+func (m *Manager) Start(ctx context.Context) {
+	if m.cancel != nil {
+		log.Errorf("Manager already started")
+		return
+	}
+
+	m.update.SetDaemonVersion(m.currentVersion)
+	m.update.SetOnUpdateListener(func() {
+		// Non-blocking send: a pending signal already guarantees a re-check.
+		select {
+		case m.updateChannel <- struct{}{}:
+		default:
+		}
+	})
+	go m.update.StartFetcher()
+
+	ctx, cancel := context.WithCancel(ctx)
+	m.cancel = cancel
+
+	m.wg.Add(1)
+	go m.updateLoop(ctx)
+}
+
+// SetVersion records the agent version Management wants this client on.
+// "latest" switches to latest-release tracking, "" clears any target, and a
+// concrete version is parsed as semver. A genuinely new target signals the
+// update loop (non-blocking); setting the same version again is a no-op.
+func (m *Manager) SetVersion(expectedVersion string) {
+	log.Infof("set expected agent version for upgrade: %s", expectedVersion)
+	if m.cancel == nil {
+		log.Errorf("manager not started")
+		return
+	}
+
+	m.updateMutex.Lock()
+	defer m.updateMutex.Unlock()
+
+	if expectedVersion == "" {
+		log.Errorf("empty expected version provided")
+		m.expectedVersion = nil
+		m.updateToLatestVersion = false
+		return
+	}
+
+	if expectedVersion == latestVersion {
+		m.updateToLatestVersion = true
+		m.expectedVersion = nil
+	} else {
+		expectedSemVer, err := v.NewVersion(expectedVersion)
+		if err != nil {
+			log.Errorf("error parsing version: %v", err)
+			return
+		}
+		// Deduplicate: identical target means no new signal to the loop.
+		if m.expectedVersion != nil && m.expectedVersion.Equal(expectedSemVer) {
+			return
+		}
+		m.expectedVersion = expectedSemVer
+		m.updateToLatestVersion = false
+	}
+
+	select {
+	case m.mgmUpdateChan <- struct{}{}:
+	default:
+	}
+}
+
+// Stop cancels the update loop, stops the version fetcher, and waits for the
+// loop goroutine to exit. m.cancel is not reset, so Start cannot be called
+// again afterwards (it would log "already started").
+func (m *Manager) Stop() {
+	if m.cancel == nil {
+		return
+	}
+
+	m.cancel()
+	m.updateMutex.Lock()
+	// update may already be nil if the loop's onContextCancel ran first.
+	if m.update != nil {
+		m.update.StopWatch()
+		m.update = nil
+	}
+	m.updateMutex.Unlock()
+
+	m.wg.Wait()
+}
+
+// onContextCancel tears down the version fetcher when the loop's context is
+// cancelled externally (i.e. without Stop being called); the nil check makes
+// it idempotent with Stop's own teardown.
+func (m *Manager) onContextCancel() {
+	if m.cancel == nil {
+		return
+	}
+
+	m.updateMutex.Lock()
+	defer m.updateMutex.Unlock()
+	if m.update != nil {
+		m.update.StopWatch()
+		m.update = nil
+	}
+}
+
+// updateLoop is the Manager's single worker goroutine: it re-evaluates the
+// update decision whenever Management pushes a new target (mgmUpdateChan) or
+// the fetcher reports a new latest release (updateChannel), and exits on
+// context cancellation.
+func (m *Manager) updateLoop(ctx context.Context) {
+	defer m.wg.Done()
+
+	for {
+		select {
+		case <-ctx.Done():
+			m.onContextCancel()
+			return
+		case <-m.mgmUpdateChan:
+		case <-m.updateChannel:
+			log.Infof("fetched new version info")
+		}
+
+		m.handleUpdate(ctx)
+	}
+}
+
+// handleUpdate resolves the effective target version (explicit version from
+// Management, or the fetched latest when tracking "latest"), decides via
+// shouldUpdate whether to act, and if so: publishes progress events, persists
+// UpdateState for the post-restart success check, and runs triggerUpdateFn.
+// Shared fields are read under updateMutex and released before the
+// (potentially slow) trigger runs.
+func (m *Manager) handleUpdate(ctx context.Context) {
+	var updateVersion *v.Version
+
+	m.updateMutex.Lock()
+	// update is nil after Stop/onContextCancel; nothing to do then.
+	if m.update == nil {
+		m.updateMutex.Unlock()
+		return
+	}
+
+	expectedVersion := m.expectedVersion
+	useLatest := m.updateToLatestVersion
+	curLatestVersion := m.update.LatestVersion()
+	m.updateMutex.Unlock()
+
+	switch {
+	// Resolve "latest" to actual version
+	case useLatest:
+		if curLatestVersion == nil {
+			log.Tracef("latest version not fetched yet")
+			return
+		}
+		updateVersion = curLatestVersion
+	// Update to specific version
+	case expectedVersion != nil:
+		updateVersion = expectedVersion
+	default:
+		log.Debugf("no expected version information set")
+		return
+	}
+
+	log.Debugf("checking update option, current version: %s, target version: %s", m.currentVersion, updateVersion)
+	if !m.shouldUpdate(updateVersion) {
+		return
+	}
+
+	m.lastTrigger = time.Now()
+	log.Infof("Auto-update triggered, current version: %s, target version: %s", m.currentVersion, updateVersion)
+	m.statusRecorder.PublishEvent(
+		cProto.SystemEvent_CRITICAL,
+		cProto.SystemEvent_SYSTEM,
+		"Automatically updating client",
+		"Your client version is older than auto-update version set in Management, updating client now.",
+		nil,
+	)
+
+	// Empty title/body with metadata: instructs the UI to show its progress
+	// window for the given version.
+	m.statusRecorder.PublishEvent(
+		cProto.SystemEvent_CRITICAL,
+		cProto.SystemEvent_SYSTEM,
+		"",
+		"",
+		map[string]string{"progress_window": "show", "version": updateVersion.String()},
+	)
+
+	updateState := UpdateState{
+		PreUpdateVersion: m.currentVersion,
+		TargetVersion:    updateVersion.String(),
+	}
+
+	// Persist the state so CheckUpdateSuccess can report the outcome after
+	// the daemon restarts; failures are non-fatal for the update itself.
+	if err := m.stateManager.UpdateState(updateState); err != nil {
+		log.Warnf("failed to update state: %v", err)
+	} else {
+		if err = m.stateManager.PersistState(ctx); err != nil {
+			log.Warnf("failed to persist state: %v", err)
+		}
+	}
+
+	if err := m.triggerUpdateFn(ctx, updateVersion.String()); err != nil {
+		log.Errorf("Error triggering auto-update: %v", err)
+		m.statusRecorder.PublishEvent(
+			cProto.SystemEvent_ERROR,
+			cProto.SystemEvent_SYSTEM,
+			"Auto-update failed",
+			fmt.Sprintf("Auto-update failed: %v", err),
+			nil,
+		)
+	}
+}
+
+// loadAndDeleteUpdateState loads the update state, deletes it from storage, and returns it.
+// Returns errNoUpdateState if no state exists (despite the stale "nil" note
+// above, the error — not a nil state — signals absence).
+func (m *Manager) loadAndDeleteUpdateState(ctx context.Context) (*UpdateState, error) {
+	stateType := &UpdateState{}
+
+	// Re-registering is harmless and makes this method usable before Start.
+	m.stateManager.RegisterState(stateType)
+	if err := m.stateManager.LoadState(stateType); err != nil {
+		return nil, fmt.Errorf("load state: %w", err)
+	}
+
+	state := m.stateManager.GetState(stateType)
+	if state == nil {
+		return nil, errNoUpdateState
+	}
+
+	updateState, ok := state.(*UpdateState)
+	if !ok {
+		return nil, fmt.Errorf("failed to cast state to UpdateState")
+	}
+
+	// Delete and persist so the state is consumed exactly once.
+	if err := m.stateManager.DeleteState(updateState); err != nil {
+		return nil, fmt.Errorf("delete state: %w", err)
+	}
+
+	if err := m.stateManager.PersistState(ctx); err != nil {
+		return nil, fmt.Errorf("persist state: %w", err)
+	}
+
+	return updateState, nil
+}
+
+// shouldUpdate decides whether to trigger an update to updateVersion.
+// It refuses when: the daemon runs the "development" build, the current
+// version fails to parse, the current version is already >= the target, or an
+// update was triggered less than 5 minutes ago (throttle; lastTrigger is only
+// touched from the updateLoop goroutine).
+func (m *Manager) shouldUpdate(updateVersion *v.Version) bool {
+	if m.currentVersion == developmentVersion {
+		log.Debugf("skipping auto-update, running development version")
+		return false
+	}
+	currentVersion, err := v.NewVersion(m.currentVersion)
+	if err != nil {
+		log.Errorf("error checking for update, error parsing version `%s`: %v", m.currentVersion, err)
+		return false
+	}
+	if currentVersion.GreaterThanOrEqual(updateVersion) {
+		log.Infof("current version (%s) is equal to or higher than auto-update version (%s)", m.currentVersion, updateVersion)
+		return false
+	}
+
+	if time.Since(m.lastTrigger) < 5*time.Minute {
+		log.Debugf("skipping auto-update, last update was %s ago", time.Since(m.lastTrigger))
+		return false
+	}
+
+	return true
+}
+
+// lastResultErrReason returns the failure reason left behind by the last
+// installer run (consuming the result file), or "" when there is none.
+func (m *Manager) lastResultErrReason() string {
+	inst := installer.New()
+	result := installer.NewResultHandler(inst.TempDir())
+	return result.GetErrorResultReason()
+}
+
+// triggerUpdate runs the platform installer for targetVersion; it is the
+// default value of triggerUpdateFn (tests substitute a stub).
+func (m *Manager) triggerUpdate(ctx context.Context, targetVersion string) error {
+	inst := installer.New()
+	return inst.RunInstallation(ctx, targetVersion)
+}
diff --git a/client/internal/updatemanager/manager_test.go b/client/internal/updatemanager/manager_test.go
new file mode 100644
index 000000000..20ddec10d
--- /dev/null
+++ b/client/internal/updatemanager/manager_test.go
@@ -0,0 +1,214 @@
+//go:build windows || darwin
+
+package updatemanager
+
+import (
+ "context"
+ "fmt"
+ "path"
+ "testing"
+ "time"
+
+ v "github.com/hashicorp/go-version"
+
+ "github.com/netbirdio/netbird/client/internal/peer"
+ "github.com/netbirdio/netbird/client/internal/statemanager"
+)
+
+// versionUpdateMock is a test double for the manager's version-update source.
+// latestVersion is what LatestVersion reports; onUpdate captures the listener
+// registered via SetOnUpdateListener so tests can fire it manually.
+//
+// Receivers are named m (not v): a receiver named v would shadow the
+// go-version package alias v, breaking the *v.Version result type of
+// LatestVersion. Pointer receivers are used throughout so the methods share
+// one instance and latestVersion mutations are observed.
+type versionUpdateMock struct {
+	latestVersion *v.Version
+	onUpdate      func()
+}
+
+// StopWatch is a no-op in the mock.
+func (m *versionUpdateMock) StopWatch() {}
+
+// SetDaemonVersion is a no-op in the mock and always reports false.
+func (m *versionUpdateMock) SetDaemonVersion(newVersion string) bool {
+	return false
+}
+
+// SetOnUpdateListener stores the listener for tests to trigger directly.
+func (m *versionUpdateMock) SetOnUpdateListener(updateFn func()) {
+	m.onUpdate = updateFn
+}
+
+// LatestVersion returns the version configured on the mock (may be nil).
+func (m *versionUpdateMock) LatestVersion() *v.Version {
+	return m.latestVersion
+}
+
+// StartFetcher is a no-op in the mock.
+func (m *versionUpdateMock) StartFetcher() {}
+
+// Test_LatestVersion verifies the "latest" auto-update flow: an update fires
+// at most once per throttle window, and a newly fetched latest version
+// triggers an update through the onUpdate listener.
+func Test_LatestVersion(t *testing.T) {
+	testMatrix := []struct {
+		name                 string
+		daemonVersion        string
+		initialLatestVersion *v.Version
+		latestVersion        *v.Version
+		shouldUpdateInit     bool
+		shouldUpdateLater    bool
+	}{
+		{
+			name:                 "Should only trigger update once due to time between triggers being < 5 Minutes",
+			daemonVersion:        "1.0.0",
+			initialLatestVersion: v.Must(v.NewSemver("1.0.1")),
+			latestVersion:        v.Must(v.NewSemver("1.0.2")),
+			shouldUpdateInit:     true,
+			shouldUpdateLater:    false,
+		},
+		{
+			name:                 "Shouldn't update initially, but should update as soon as latest version is fetched",
+			daemonVersion:        "1.0.0",
+			initialLatestVersion: nil,
+			latestVersion:        v.Must(v.NewSemver("1.0.1")),
+			shouldUpdateInit:     false,
+			shouldUpdateLater:    true,
+		},
+	}
+
+	for idx, c := range testMatrix {
+		mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion}
+		tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
+		m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile))
+		m.update = mockUpdate
+
+		targetVersionChan := make(chan string, 1)
+
+		// Stub out the real installer; record the requested version instead.
+		m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error {
+			targetVersionChan <- targetVersion
+			return nil
+		}
+		m.currentVersion = c.daemonVersion
+		m.Start(context.Background())
+		m.SetVersion("latest")
+		var triggeredInit bool
+		select {
+		case targetVersion := <-targetVersionChan:
+			if targetVersion != c.initialLatestVersion.String() {
+				t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), targetVersion)
+			}
+			triggeredInit = true
+		case <-time.After(10 * time.Millisecond):
+			triggeredInit = false
+		}
+		if triggeredInit != c.shouldUpdateInit {
+			t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit)
+		}
+
+		// Simulate the fetcher delivering a new latest version.
+		mockUpdate.latestVersion = c.latestVersion
+		mockUpdate.onUpdate()
+
+		var triggeredLater bool
+		select {
+		case targetVersion := <-targetVersionChan:
+			if targetVersion != c.latestVersion.String() {
+				t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
+			}
+			triggeredLater = true
+		case <-time.After(10 * time.Millisecond):
+			triggeredLater = false
+		}
+		if triggeredLater != c.shouldUpdateLater {
+			t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater)
+		}
+
+		m.Stop()
+	}
+}
+
+// Test_HandleUpdate verifies SetVersion handling for both explicit target
+// versions and "latest", covering no-op cases (same/older version), invalid
+// versions, and a missing latest version.
+func Test_HandleUpdate(t *testing.T) {
+	testMatrix := []struct {
+		name            string
+		daemonVersion   string
+		latestVersion   *v.Version
+		expectedVersion string
+		shouldUpdate    bool
+	}{
+		{
+			name:            "Update to a specific version should update regardless of if latestVersion is available yet",
+			daemonVersion:   "0.55.0",
+			latestVersion:   nil,
+			expectedVersion: "0.56.0",
+			shouldUpdate:    true,
+		},
+		{
+			name:            "Update to specific version should not update if version matches",
+			daemonVersion:   "0.55.0",
+			latestVersion:   nil,
+			expectedVersion: "0.55.0",
+			shouldUpdate:    false,
+		},
+		{
+			name:            "Update to specific version should not update if current version is newer",
+			daemonVersion:   "0.55.0",
+			latestVersion:   nil,
+			expectedVersion: "0.54.0",
+			shouldUpdate:    false,
+		},
+		{
+			name:            "Update to latest version should update if latest is newer",
+			daemonVersion:   "0.55.0",
+			latestVersion:   v.Must(v.NewSemver("0.56.0")),
+			expectedVersion: "latest",
+			shouldUpdate:    true,
+		},
+		{
+			name:            "Update to latest version should not update if latest == current",
+			daemonVersion:   "0.56.0",
+			latestVersion:   v.Must(v.NewSemver("0.56.0")),
+			expectedVersion: "latest",
+			shouldUpdate:    false,
+		},
+		{
+			name:            "Should not update if daemon version is invalid",
+			daemonVersion:   "development",
+			latestVersion:   v.Must(v.NewSemver("1.0.0")),
+			expectedVersion: "latest",
+			shouldUpdate:    false,
+		},
+		{
+			name:            "Should not update if expecting latest and latest version is unavailable",
+			daemonVersion:   "0.55.0",
+			latestVersion:   nil,
+			expectedVersion: "latest",
+			shouldUpdate:    false,
+		},
+		{
+			name:            "Should not update if expected version is invalid",
+			daemonVersion:   "0.55.0",
+			latestVersion:   nil,
+			expectedVersion: "development",
+			shouldUpdate:    false,
+		},
+	}
+	for idx, c := range testMatrix {
+		tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
+		m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile))
+		m.update = &versionUpdateMock{latestVersion: c.latestVersion}
+		targetVersionChan := make(chan string, 1)
+
+		// Stub out the real installer; record the requested version instead.
+		m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error {
+			targetVersionChan <- targetVersion
+			return nil
+		}
+
+		m.currentVersion = c.daemonVersion
+		m.Start(context.Background())
+		m.SetVersion(c.expectedVersion)
+
+		var updateTriggered bool
+		select {
+		case targetVersion := <-targetVersionChan:
+			if c.expectedVersion == "latest" && targetVersion != c.latestVersion.String() {
+				t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
+			} else if c.expectedVersion != "latest" && targetVersion != c.expectedVersion {
+				t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.expectedVersion, targetVersion)
+			}
+			updateTriggered = true
+		case <-time.After(10 * time.Millisecond):
+			updateTriggered = false
+		}
+
+		if updateTriggered != c.shouldUpdate {
+			t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered)
+		}
+		m.Stop()
+	}
+}
diff --git a/client/internal/updatemanager/manager_unsupported.go b/client/internal/updatemanager/manager_unsupported.go
new file mode 100644
index 000000000..4e87c2d77
--- /dev/null
+++ b/client/internal/updatemanager/manager_unsupported.go
@@ -0,0 +1,39 @@
+//go:build !windows && !darwin
+
+package updatemanager
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/netbirdio/netbird/client/internal/peer"
+ "github.com/netbirdio/netbird/client/internal/statemanager"
+)
+
+// Manager is a no-op stub compiled on platforms without auto-update support
+// (anything that is not windows or darwin, per the build tag above).
+type Manager struct{}
+
+// NewManager always fails on unsupported platforms: it returns a nil Manager
+// together with a non-nil error so callers can detect the missing support.
+// (The previous comment claimed it returns a no-op manager, which it does not.)
+func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
+	return nil, fmt.Errorf("update manager is not supported on this platform")
+}
+
+// CheckUpdateSuccess is a no-op on unsupported platforms.
+// The empty body never touches the receiver, so calling it on a nil *Manager is safe.
+func (m *Manager) CheckUpdateSuccess(ctx context.Context) {
+	// no-op
+}
+
+// Start is a no-op on unsupported platforms.
+// The empty body never touches the receiver, so calling it on a nil *Manager is safe.
+func (m *Manager) Start(ctx context.Context) {
+	// no-op
+}
+
+// SetVersion is a no-op on unsupported platforms.
+// The empty body never touches the receiver, so calling it on a nil *Manager is safe.
+func (m *Manager) SetVersion(expectedVersion string) {
+	// no-op
+}
+
+// Stop is a no-op on unsupported platforms.
+// The empty body never touches the receiver, so calling it on a nil *Manager is safe.
+func (m *Manager) Stop() {
+	// no-op
+}
diff --git a/client/internal/updatemanager/reposign/artifact.go b/client/internal/updatemanager/reposign/artifact.go
new file mode 100644
index 000000000..3d4fe9c74
--- /dev/null
+++ b/client/internal/updatemanager/reposign/artifact.go
@@ -0,0 +1,302 @@
+package reposign
+
+import (
+ "crypto/ed25519"
+ "crypto/rand"
+ "encoding/binary"
+ "encoding/json"
+ "encoding/pem"
+ "errors"
+ "fmt"
+ "hash"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/crypto/blake2s"
+)
+
+const (
+ tagArtifactPrivate = "ARTIFACT PRIVATE KEY"
+ tagArtifactPublic = "ARTIFACT PUBLIC KEY"
+
+ maxArtifactKeySignatureAge = 10 * 365 * 24 * time.Hour
+ maxArtifactSignatureAge = 10 * 365 * 24 * time.Hour
+)
+
+// ArtifactHash wraps a hash.Hash (BLAKE2s-256) used to digest artifact bytes.
+// (Earlier comment claimed it counts bytes written — no counting is done.)
+type ArtifactHash struct {
+	hash.Hash
+}
+
+// NewArtifactHash returns an initialized ArtifactHash using BLAKE2s-256.
+func NewArtifactHash() *ArtifactHash {
+	h, err := blake2s.New256(nil)
+	if err != nil {
+		panic(err) // Should never happen with nil Key
+	}
+	return &ArtifactHash{Hash: h}
+}
+
+// Write is a plain delegation to the embedded hash.Hash; it adds no behavior
+// beyond making the forwarding explicit.
+func (ah *ArtifactHash) Write(b []byte) (int, error) {
+	return ah.Hash.Write(b)
+}
+
+// ArtifactKey is a signing key (private half plus metadata) used to sign
+// artifacts; see SignData.
+type ArtifactKey struct {
+	PrivateKey
+}
+
+// String returns a human-readable summary of the key: its ID plus RFC 3339
+// creation and expiry timestamps. It never exposes key material.
+func (k ArtifactKey) String() string {
+	return fmt.Sprintf(
+		"ArtifactKey[ID=%s, CreatedAt=%s, ExpiresAt=%s]",
+		k.Metadata.ID,
+		k.Metadata.CreatedAt.Format(time.RFC3339),
+		k.Metadata.ExpiresAt.Format(time.RFC3339),
+	)
+}
+
+// GenerateArtifactKey creates a fresh ed25519 artifact signing key expiring
+// after the given duration, PEM-encodes both halves (metadata embedded as
+// JSON inside the PEM block), and signs the public PEM with rootKey.
+// It returns the key, the private PEM, the public PEM, and the signature,
+// or an error if the root key has already expired or any step fails.
+func GenerateArtifactKey(rootKey *RootKey, expiration time.Duration) (*ArtifactKey, []byte, []byte, []byte, error) {
+	// Verify root key is still valid
+	if !rootKey.Metadata.ExpiresAt.IsZero() && time.Now().After(rootKey.Metadata.ExpiresAt) {
+		return nil, nil, nil, nil, fmt.Errorf("root key has expired on %s", rootKey.Metadata.ExpiresAt.Format(time.RFC3339))
+	}
+
+	now := time.Now()
+	expirationTime := now.Add(expiration)
+	pub, priv, err := ed25519.GenerateKey(rand.Reader)
+	if err != nil {
+		return nil, nil, nil, nil, fmt.Errorf("generate ed25519 key: %w", err)
+	}
+
+	metadata := KeyMetadata{
+		ID:        computeKeyID(pub),
+		CreatedAt: now.UTC(),
+		ExpiresAt: expirationTime.UTC(),
+	}
+
+	ak := &ArtifactKey{
+		PrivateKey{
+			Key:      priv,
+			Metadata: metadata,
+		},
+	}
+
+	// Marshal PrivateKey struct to JSON
+	privJSON, err := json.Marshal(ak.PrivateKey)
+	if err != nil {
+		return nil, nil, nil, nil, fmt.Errorf("failed to marshal private key: %w", err)
+	}
+
+	// Marshal PublicKey struct to JSON
+	pubKey := PublicKey{
+		Key:      pub,
+		Metadata: metadata,
+	}
+	pubJSON, err := json.Marshal(pubKey)
+	if err != nil {
+		return nil, nil, nil, nil, fmt.Errorf("failed to marshal public key: %w", err)
+	}
+
+	// Encode to PEM with metadata embedded in bytes
+	privPEM := pem.EncodeToMemory(&pem.Block{
+		Type:  tagArtifactPrivate,
+		Bytes: privJSON,
+	})
+
+	pubPEM := pem.EncodeToMemory(&pem.Block{
+		Type:  tagArtifactPublic,
+		Bytes: pubJSON,
+	})
+
+	// Sign the public key with the root key
+	signature, err := SignArtifactKey(*rootKey, pubPEM)
+	if err != nil {
+		return nil, nil, nil, nil, fmt.Errorf("failed to sign artifact key: %w", err)
+	}
+
+	return ak, privPEM, pubPEM, signature, nil
+}
+
+// ParseArtifactKey decodes a PEM-encoded artifact private key as produced by
+// GenerateArtifactKey. The PEM block type must match tagArtifactPrivate.
+func ParseArtifactKey(privKeyPEM []byte) (ArtifactKey, error) {
+	pk, err := parsePrivateKey(privKeyPEM, tagArtifactPrivate)
+	if err != nil {
+		return ArtifactKey{}, fmt.Errorf("failed to parse artifact Key: %w", err)
+	}
+	return ArtifactKey{pk}, nil
+}
+
+// ParseArtifactPubKey decodes a single PEM-encoded artifact public key
+// (block type tagArtifactPublic); any trailing PEM data is ignored.
+func ParseArtifactPubKey(data []byte) (PublicKey, error) {
+	pk, _, err := parsePublicKey(data, tagArtifactPublic)
+	return pk, err
+}
+
+// BundleArtifactKeys concatenates the given artifact public keys into one PEM
+// bundle and signs the whole bundle with rootKey. It returns the bundle and
+// its signature, and errors when keys is empty or signing fails.
+func BundleArtifactKeys(rootKey *RootKey, keys []PublicKey) ([]byte, []byte, error) {
+	if len(keys) == 0 {
+		return nil, nil, errors.New("no keys to bundle")
+	}
+
+	// Create bundle by concatenating PEM-encoded keys
+	var pubBundle []byte
+
+	for _, pk := range keys {
+		// Marshal PublicKey struct to JSON
+		pubJSON, err := json.Marshal(pk)
+		if err != nil {
+			return nil, nil, fmt.Errorf("failed to marshal public key: %w", err)
+		}
+
+		// Encode to PEM
+		pubPEM := pem.EncodeToMemory(&pem.Block{
+			Type:  tagArtifactPublic,
+			Bytes: pubJSON,
+		})
+
+		pubBundle = append(pubBundle, pubPEM...)
+	}
+
+	// Sign the entire bundle with the root key
+	signature, err := SignArtifactKey(*rootKey, pubBundle)
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to sign artifact key bundle: %w", err)
+	}
+
+	return pubBundle, signature, nil
+}
+
+// ValidateArtifactKeys verifies that data (a PEM bundle of artifact public
+// keys) was signed by one of publicRootKeys, then filters out expired and
+// revoked keys. It returns the remaining usable keys, or an error when the
+// signature is future-dated, stale, invalid, or no usable key remains.
+func ValidateArtifactKeys(publicRootKeys []PublicKey, data []byte, signature Signature, revocationList *RevocationList) ([]PublicKey, error) {
+	now := time.Now().UTC()
+	if signature.Timestamp.After(now.Add(maxClockSkew)) {
+		err := fmt.Errorf("signature timestamp is in the future: %v", signature.Timestamp)
+		log.Debugf("artifact signature error: %v", err)
+		return nil, err
+	}
+	if now.Sub(signature.Timestamp) > maxArtifactKeySignatureAge {
+		err := fmt.Errorf("signature is too old: %v (created %v)", now.Sub(signature.Timestamp), signature.Timestamp)
+		log.Debugf("artifact signature error: %v", err)
+		return nil, err
+	}
+
+	// Reconstruct the signed message: artifact_key_data || timestamp
+	msg := make([]byte, 0, len(data)+8)
+	msg = append(msg, data...)
+	msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix()))
+
+	if !verifyAny(publicRootKeys, msg, signature.Signature) {
+		return nil, errors.New("failed to verify signature of artifact keys")
+	}
+
+	pubKeys, err := parsePublicKeyBundle(data, tagArtifactPublic)
+	if err != nil {
+		log.Debugf("failed to parse public keys: %s", err)
+		return nil, err
+	}
+
+	validKeys := make([]PublicKey, 0, len(pubKeys))
+	for _, pubKey := range pubKeys {
+		// Filter out expired keys
+		if !pubKey.Metadata.ExpiresAt.IsZero() && now.After(pubKey.Metadata.ExpiresAt) {
+			log.Debugf("Key %s is expired at %v (current time %v)",
+				pubKey.Metadata.ID, pubKey.Metadata.ExpiresAt, now)
+			continue
+		}
+
+		// Filter out revoked keys
+		if revocationList != nil {
+			if revTime, revoked := revocationList.Revoked[pubKey.Metadata.ID]; revoked {
+				log.Debugf("Key %s is revoked as of %v (created %v)",
+					pubKey.Metadata.ID, revTime, pubKey.Metadata.CreatedAt)
+				continue
+			}
+		}
+		validKeys = append(validKeys, pubKey)
+	}
+
+	if len(validKeys) == 0 {
+		log.Debugf("no valid public keys found for artifact keys")
+		// Keys can be dropped for expiry as well as revocation, so say both.
+		return nil, fmt.Errorf("all %d artifact keys are revoked or expired", len(pubKeys))
+	}
+
+	return validKeys, nil
+}
+
+// ValidateArtifact verifies signature over data using the key from
+// artifactPubKeys whose ID matches signature.KeyID. The signed message is
+// hash || length || timestamp, where hash is BLAKE2s-256 of data. It rejects
+// future-dated or stale signatures, signatures made after the signing key's
+// expiry, and signatures from unknown keys.
+func ValidateArtifact(artifactPubKeys []PublicKey, data []byte, signature Signature) error {
+	// Validate signature timestamp
+	now := time.Now().UTC()
+	if signature.Timestamp.After(now.Add(maxClockSkew)) {
+		err := fmt.Errorf("artifact signature timestamp is in the future: %v", signature.Timestamp)
+		log.Debugf("failed to verify signature of artifact: %s", err)
+		return err
+	}
+	if now.Sub(signature.Timestamp) > maxArtifactSignatureAge {
+		return fmt.Errorf("artifact signature is too old: %v (created %v)",
+			now.Sub(signature.Timestamp), signature.Timestamp)
+	}
+
+	h := NewArtifactHash()
+	if _, err := h.Write(data); err != nil {
+		return fmt.Errorf("failed to hash artifact: %w", err)
+	}
+	hash := h.Sum(nil)
+
+	// Reconstruct the signed message: hash || length || timestamp
+	msg := make([]byte, 0, len(hash)+8+8)
+	msg = append(msg, hash...)
+	msg = binary.LittleEndian.AppendUint64(msg, uint64(len(data)))
+	msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix()))
+
+	// Find matching Key and verify
+	for _, keyInfo := range artifactPubKeys {
+		if keyInfo.Metadata.ID == signature.KeyID {
+			// Check Key expiration
+			if !keyInfo.Metadata.ExpiresAt.IsZero() &&
+				signature.Timestamp.After(keyInfo.Metadata.ExpiresAt) {
+				return fmt.Errorf("signing Key %s expired at %v, signature from %v",
+					signature.KeyID, keyInfo.Metadata.ExpiresAt, signature.Timestamp)
+			}
+
+			if ed25519.Verify(keyInfo.Key, msg, signature.Signature) {
+				log.Debugf("artifact verified successfully with Key: %s", signature.KeyID)
+				return nil
+			}
+			return fmt.Errorf("signature verification failed for Key %s", signature.KeyID)
+		}
+	}
+
+	return fmt.Errorf("no signing Key found with ID %s", signature.KeyID)
+}
+
+// SignData hashes data with BLAKE2s-256 and signs hash || length || timestamp
+// with artifactKey, returning the JSON-encoded Signature bundle. It fails on
+// empty input or when the signing key has already expired.
+func SignData(artifactKey ArtifactKey, data []byte) ([]byte, error) {
+	if len(data) == 0 {
+		return nil, fmt.Errorf("artifact length must be positive, got %d", len(data))
+	}
+
+	timestamp := time.Now().UTC()
+
+	// Reject an expired key up front, before doing any hashing work.
+	if !artifactKey.Metadata.ExpiresAt.IsZero() && timestamp.After(artifactKey.Metadata.ExpiresAt) {
+		return nil, fmt.Errorf("artifact key expired at %v", artifactKey.Metadata.ExpiresAt)
+	}
+
+	h := NewArtifactHash()
+	if _, err := h.Write(data); err != nil {
+		return nil, fmt.Errorf("failed to write artifact hash: %w", err)
+	}
+	hash := h.Sum(nil)
+
+	// Create message: hash || length || timestamp
+	msg := make([]byte, 0, len(hash)+8+8)
+	msg = append(msg, hash...)
+	msg = binary.LittleEndian.AppendUint64(msg, uint64(len(data)))
+	msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix()))
+
+	sig := ed25519.Sign(artifactKey.Key, msg)
+
+	bundle := Signature{
+		Signature: sig,
+		Timestamp: timestamp,
+		KeyID:     artifactKey.Metadata.ID,
+		Algorithm: "ed25519",
+		HashAlgo:  "blake2s",
+	}
+
+	return json.Marshal(bundle)
+}
diff --git a/client/internal/updatemanager/reposign/artifact_test.go b/client/internal/updatemanager/reposign/artifact_test.go
new file mode 100644
index 000000000..8865e2d0a
--- /dev/null
+++ b/client/internal/updatemanager/reposign/artifact_test.go
@@ -0,0 +1,1080 @@
+package reposign
+
+import (
+ "crypto/ed25519"
+ "crypto/rand"
+ "encoding/json"
+ "encoding/pem"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// Test ArtifactHash
+
+// TestNewArtifactHash checks the constructor returns a usable, non-nil hash.
+func TestNewArtifactHash(t *testing.T) {
+	h := NewArtifactHash()
+	assert.NotNil(t, h)
+	assert.NotNil(t, h.Hash)
+}
+
+// TestArtifactHash_Write checks Write reports the byte count and the digest
+// is 32 bytes (BLAKE2s-256).
+func TestArtifactHash_Write(t *testing.T) {
+	h := NewArtifactHash()
+
+	data := []byte("test data")
+	n, err := h.Write(data)
+	require.NoError(t, err)
+	assert.Equal(t, len(data), n)
+
+	hash := h.Sum(nil)
+	assert.NotEmpty(t, hash)
+	assert.Equal(t, 32, len(hash)) // BLAKE2s-256
+}
+
+// TestArtifactHash_Deterministic checks that hashing the same input twice
+// yields the same digest.
+func TestArtifactHash_Deterministic(t *testing.T) {
+	data := []byte("test data")
+
+	h1 := NewArtifactHash()
+	if _, err := h1.Write(data); err != nil {
+		t.Fatal(err)
+	}
+	hash1 := h1.Sum(nil)
+
+	h2 := NewArtifactHash()
+	if _, err := h2.Write(data); err != nil {
+		t.Fatal(err)
+	}
+	hash2 := h2.Sum(nil)
+
+	assert.Equal(t, hash1, hash2)
+}
+
+// TestArtifactHash_DifferentData checks that different inputs yield
+// different digests.
+func TestArtifactHash_DifferentData(t *testing.T) {
+	h1 := NewArtifactHash()
+	if _, err := h1.Write([]byte("data1")); err != nil {
+		t.Fatal(err)
+	}
+	hash1 := h1.Sum(nil)
+
+	h2 := NewArtifactHash()
+	if _, err := h2.Write([]byte("data2")); err != nil {
+		t.Fatal(err)
+	}
+	hash2 := h2.Sum(nil)
+
+	assert.NotEqual(t, hash1, hash2)
+}
+
+// Test ArtifactKey.String()
+
+// TestArtifactKey_String checks String includes the key ID and the creation
+// and expiry dates.
+func TestArtifactKey_String(t *testing.T) {
+	pub, priv, err := ed25519.GenerateKey(rand.Reader)
+	require.NoError(t, err)
+
+	createdAt := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
+	expiresAt := time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC)
+
+	ak := ArtifactKey{
+		PrivateKey{
+			Key: priv,
+			Metadata: KeyMetadata{
+				ID:        computeKeyID(pub),
+				CreatedAt: createdAt,
+				ExpiresAt: expiresAt,
+			},
+		},
+	}
+
+	str := ak.String()
+	assert.Contains(t, str, "ArtifactKey")
+	assert.Contains(t, str, computeKeyID(pub).String())
+	assert.Contains(t, str, "2024-01-15")
+	assert.Contains(t, str, "2025-01-15")
+}
+
+// Test GenerateArtifactKey
+
+// TestGenerateArtifactKey_Valid checks generation with a valid root key
+// returns all outputs and an expiry within the requested window.
+func TestGenerateArtifactKey_Valid(t *testing.T) {
+	// Create root key
+	rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+	require.NoError(t, err)
+
+	rootKey := &RootKey{
+		PrivateKey{
+			Key: rootPriv,
+			Metadata: KeyMetadata{
+				ID:        computeKeyID(rootPub),
+				CreatedAt: time.Now().UTC(),
+				ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(),
+			},
+		},
+	}
+
+	// Generate artifact key
+	ak, privPEM, pubPEM, signature, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+	require.NoError(t, err)
+	assert.NotNil(t, ak)
+	assert.NotEmpty(t, privPEM)
+	assert.NotEmpty(t, pubPEM)
+	assert.NotEmpty(t, signature)
+
+	// Verify expiration
+	assert.True(t, ak.Metadata.ExpiresAt.After(time.Now()))
+	assert.True(t, ak.Metadata.ExpiresAt.Before(time.Now().Add(31*24*time.Hour)))
+}
+
+// TestGenerateArtifactKey_ExpiredRoot checks generation fails when the root
+// key has expired.
+func TestGenerateArtifactKey_ExpiredRoot(t *testing.T) {
+	rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+	require.NoError(t, err)
+
+	// Create expired root key
+	rootKey := &RootKey{
+		PrivateKey{
+			Key: rootPriv,
+			Metadata: KeyMetadata{
+				ID:        computeKeyID(rootPub),
+				CreatedAt: time.Now().Add(-2 * 365 * 24 * time.Hour).UTC(),
+				ExpiresAt: time.Now().Add(-1 * time.Hour).UTC(), // Expired
+			},
+		},
+	}
+
+	_, _, _, _, err = GenerateArtifactKey(rootKey, 30*24*time.Hour)
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "expired")
+}
+
+// TestGenerateArtifactKey_NoExpiration checks a zero ExpiresAt root key
+// (never expires) is accepted.
+func TestGenerateArtifactKey_NoExpiration(t *testing.T) {
+	rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+	require.NoError(t, err)
+
+	// Root key with no expiration
+	rootKey := &RootKey{
+		PrivateKey{
+			Key: rootPriv,
+			Metadata: KeyMetadata{
+				ID:        computeKeyID(rootPub),
+				CreatedAt: time.Now().UTC(),
+				ExpiresAt: time.Time{}, // No expiration
+			},
+		},
+	}
+
+	ak, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+	require.NoError(t, err)
+	assert.NotNil(t, ak)
+}
+
+// Test ParseArtifactKey
+
+// TestParseArtifactKey_Valid checks a generated private key round-trips
+// through PEM encoding and parsing.
+func TestParseArtifactKey_Valid(t *testing.T) {
+	rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+	require.NoError(t, err)
+
+	rootKey := &RootKey{
+		PrivateKey{
+			Key: rootPriv,
+			Metadata: KeyMetadata{
+				ID:        computeKeyID(rootPub),
+				CreatedAt: time.Now().UTC(),
+			},
+		},
+	}
+
+	original, privPEM, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+	require.NoError(t, err)
+
+	// Parse it back
+	parsed, err := ParseArtifactKey(privPEM)
+	require.NoError(t, err)
+
+	assert.Equal(t, original.Key, parsed.Key)
+	assert.Equal(t, original.Metadata.ID, parsed.Metadata.ID)
+}
+
+// TestParseArtifactKey_InvalidPEM checks non-PEM input is rejected.
+func TestParseArtifactKey_InvalidPEM(t *testing.T) {
+	_, err := ParseArtifactKey([]byte("invalid pem"))
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "failed to parse")
+}
+
+// TestParseArtifactKey_WrongType checks a PEM block with the root-key tag is
+// rejected by the artifact-key parser.
+func TestParseArtifactKey_WrongType(t *testing.T) {
+	pub, priv, err := ed25519.GenerateKey(rand.Reader)
+	require.NoError(t, err)
+
+	// Create a root key (wrong type)
+	rootKey := RootKey{
+		PrivateKey{
+			Key: priv,
+			Metadata: KeyMetadata{
+				ID:        computeKeyID(pub),
+				CreatedAt: time.Now().UTC(),
+			},
+		},
+	}
+
+	privJSON, err := json.Marshal(rootKey.PrivateKey)
+	require.NoError(t, err)
+
+	privPEM := encodePrivateKey(privJSON, tagRootPrivate)
+
+	_, err = ParseArtifactKey(privPEM)
+	assert.Error(t, err)
+}
+
+// Test ParseArtifactPubKey
+
+// TestParseArtifactPubKey_Valid checks a generated public key PEM parses back
+// with the same ID.
+func TestParseArtifactPubKey_Valid(t *testing.T) {
+	rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+	require.NoError(t, err)
+
+	rootKey := &RootKey{
+		PrivateKey{
+			Key: rootPriv,
+			Metadata: KeyMetadata{
+				ID:        computeKeyID(rootPub),
+				CreatedAt: time.Now().UTC(),
+			},
+		},
+	}
+
+	original, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+	require.NoError(t, err)
+
+	parsed, err := ParseArtifactPubKey(pubPEM)
+	require.NoError(t, err)
+
+	assert.Equal(t, original.Metadata.ID, parsed.Metadata.ID)
+}
+
+// TestParseArtifactPubKey_Invalid checks non-PEM input is rejected.
+func TestParseArtifactPubKey_Invalid(t *testing.T) {
+	_, err := ParseArtifactPubKey([]byte("invalid"))
+	assert.Error(t, err)
+}
+
+// Test BundleArtifactKeys
+
+// TestBundleArtifactKeys_Single checks bundling a single key yields a
+// non-empty bundle and signature.
+func TestBundleArtifactKeys_Single(t *testing.T) {
+	rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+	require.NoError(t, err)
+
+	rootKey := &RootKey{
+		PrivateKey{
+			Key: rootPriv,
+			Metadata: KeyMetadata{
+				ID:        computeKeyID(rootPub),
+				CreatedAt: time.Now().UTC(),
+			},
+		},
+	}
+
+	_, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+	require.NoError(t, err)
+
+	pubKey, err := ParseArtifactPubKey(pubPEM)
+	require.NoError(t, err)
+
+	bundle, signature, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey})
+	require.NoError(t, err)
+	assert.NotEmpty(t, bundle)
+	assert.NotEmpty(t, signature)
+}
+
+// TestBundleArtifactKeys_Multiple checks a three-key bundle signs and parses
+// back into exactly three keys.
+func TestBundleArtifactKeys_Multiple(t *testing.T) {
+	rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+	require.NoError(t, err)
+
+	rootKey := &RootKey{
+		PrivateKey{
+			Key: rootPriv,
+			Metadata: KeyMetadata{
+				ID:        computeKeyID(rootPub),
+				CreatedAt: time.Now().UTC(),
+			},
+		},
+	}
+
+	// Generate 3 artifact keys
+	var pubKeys []PublicKey
+	for i := 0; i < 3; i++ {
+		_, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+		require.NoError(t, err)
+
+		pubKey, err := ParseArtifactPubKey(pubPEM)
+		require.NoError(t, err)
+		pubKeys = append(pubKeys, pubKey)
+	}
+
+	bundle, signature, err := BundleArtifactKeys(rootKey, pubKeys)
+	require.NoError(t, err)
+	assert.NotEmpty(t, bundle)
+	assert.NotEmpty(t, signature)
+
+	// Verify we can parse the bundle
+	parsed, err := parsePublicKeyBundle(bundle, tagArtifactPublic)
+	require.NoError(t, err)
+	assert.Len(t, parsed, 3)
+}
+
+// TestBundleArtifactKeys_Empty checks an empty key slice is rejected.
+func TestBundleArtifactKeys_Empty(t *testing.T) {
+	rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+	require.NoError(t, err)
+
+	rootKey := &RootKey{
+		PrivateKey{
+			Key: rootPriv,
+			Metadata: KeyMetadata{
+				ID:        computeKeyID(rootPub),
+				CreatedAt: time.Now().UTC(),
+			},
+		},
+	}
+
+	_, _, err = BundleArtifactKeys(rootKey, []PublicKey{})
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "no keys")
+}
+
+// Test ValidateArtifactKeys
+
+// TestSingleValidateArtifactKey_Valid checks that the signature produced by
+// GenerateArtifactKey over a single public-key PEM validates against the
+// root key. The ParseSignature error is checked (it was silently discarded
+// before, which would nil-deref on *sig if parsing ever failed).
+func TestSingleValidateArtifactKey_Valid(t *testing.T) {
+	// Create root key
+	rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+	require.NoError(t, err)
+
+	rootKey := &RootKey{
+		PrivateKey{
+			Key: rootPriv,
+			Metadata: KeyMetadata{
+				ID:        computeKeyID(rootPub),
+				CreatedAt: time.Now().UTC(),
+			},
+		},
+	}
+
+	rootKeys := []PublicKey{
+		{
+			Key: rootPub,
+			Metadata: KeyMetadata{
+				ID:        computeKeyID(rootPub),
+				CreatedAt: time.Now().UTC(),
+			},
+		},
+	}
+
+	// Generate artifact key
+	_, _, pubPEM, sigData, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+	require.NoError(t, err)
+
+	sig, err := ParseSignature(sigData)
+	require.NoError(t, err)
+
+	// Validate
+	validKeys, err := ValidateArtifactKeys(rootKeys, pubPEM, *sig, nil)
+	require.NoError(t, err)
+	assert.Len(t, validKeys, 1)
+}
+
+// TestValidateArtifactKeys_Valid checks a properly signed bundle validates
+// and returns its one key.
+func TestValidateArtifactKeys_Valid(t *testing.T) {
+	// Create root key
+	rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+	require.NoError(t, err)
+
+	rootKey := &RootKey{
+		PrivateKey{
+			Key: rootPriv,
+			Metadata: KeyMetadata{
+				ID:        computeKeyID(rootPub),
+				CreatedAt: time.Now().UTC(),
+			},
+		},
+	}
+
+	rootKeys := []PublicKey{
+		{
+			Key: rootPub,
+			Metadata: KeyMetadata{
+				ID:        computeKeyID(rootPub),
+				CreatedAt: time.Now().UTC(),
+			},
+		},
+	}
+
+	// Generate artifact key
+	_, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+	require.NoError(t, err)
+
+	pubKey, err := ParseArtifactPubKey(pubPEM)
+	require.NoError(t, err)
+
+	// Bundle and sign
+	bundle, sigData, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey})
+	require.NoError(t, err)
+
+	sig, err := ParseSignature(sigData)
+	require.NoError(t, err)
+
+	// Validate
+	validKeys, err := ValidateArtifactKeys(rootKeys, bundle, *sig, nil)
+	require.NoError(t, err)
+	assert.Len(t, validKeys, 1)
+}
+
+// TestValidateArtifactKeys_FutureTimestamp checks a signature dated beyond
+// the allowed clock skew is rejected.
+func TestValidateArtifactKeys_FutureTimestamp(t *testing.T) {
+	rootPub, _, err := ed25519.GenerateKey(rand.Reader)
+	require.NoError(t, err)
+
+	rootKeys := []PublicKey{
+		{
+			Key: rootPub,
+			Metadata: KeyMetadata{
+				ID:        computeKeyID(rootPub),
+				CreatedAt: time.Now().UTC(),
+			},
+		},
+	}
+
+	sig := Signature{
+		Signature: make([]byte, 64),
+		Timestamp: time.Now().UTC().Add(10 * time.Minute),
+		KeyID:     computeKeyID(rootPub),
+		Algorithm: "ed25519",
+		HashAlgo:  "blake2s",
+	}
+
+	_, err = ValidateArtifactKeys(rootKeys, []byte("data"), sig, nil)
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "in the future")
+}
+
+// TestValidateArtifactKeys_TooOld checks a signature older than the maximum
+// key-signature age is rejected.
+func TestValidateArtifactKeys_TooOld(t *testing.T) {
+	rootPub, _, err := ed25519.GenerateKey(rand.Reader)
+	require.NoError(t, err)
+
+	rootKeys := []PublicKey{
+		{
+			Key: rootPub,
+			Metadata: KeyMetadata{
+				ID:        computeKeyID(rootPub),
+				CreatedAt: time.Now().UTC(),
+			},
+		},
+	}
+
+	sig := Signature{
+		Signature: make([]byte, 64),
+		Timestamp: time.Now().UTC().Add(-20 * 365 * 24 * time.Hour),
+		KeyID:     computeKeyID(rootPub),
+		Algorithm: "ed25519",
+		HashAlgo:  "blake2s",
+	}
+
+	_, err = ValidateArtifactKeys(rootKeys, []byte("data"), sig, nil)
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "too old")
+}
+
+// TestValidateArtifactKeys_InvalidSignature checks an all-zero signature over
+// a valid bundle fails verification.
+func TestValidateArtifactKeys_InvalidSignature(t *testing.T) {
+	rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+	require.NoError(t, err)
+
+	rootKey := &RootKey{
+		PrivateKey{
+			Key: rootPriv,
+			Metadata: KeyMetadata{
+				ID:        computeKeyID(rootPub),
+				CreatedAt: time.Now().UTC(),
+			},
+		},
+	}
+
+	rootKeys := []PublicKey{
+		{
+			Key: rootPub,
+			Metadata: KeyMetadata{
+				ID:        computeKeyID(rootPub),
+				CreatedAt: time.Now().UTC(),
+			},
+		},
+	}
+
+	_, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+	require.NoError(t, err)
+
+	pubKey, err := ParseArtifactPubKey(pubPEM)
+	require.NoError(t, err)
+
+	bundle, _, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey})
+	require.NoError(t, err)
+
+	// Create invalid signature
+	invalidSig := Signature{
+		Signature: make([]byte, 64),
+		Timestamp: time.Now().UTC(),
+		KeyID:     computeKeyID(rootPub),
+		Algorithm: "ed25519",
+		HashAlgo:  "blake2s",
+	}
+
+	_, err = ValidateArtifactKeys(rootKeys, bundle, invalidSig, nil)
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "failed to verify")
+}
+
+// TestValidateArtifactKeys_WithRevocation checks that a revoked key is
+// filtered out while the remaining key is returned.
+func TestValidateArtifactKeys_WithRevocation(t *testing.T) {
+	// Create root key
+	rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+	require.NoError(t, err)
+
+	rootKey := &RootKey{
+		PrivateKey{
+			Key: rootPriv,
+			Metadata: KeyMetadata{
+				ID:        computeKeyID(rootPub),
+				CreatedAt: time.Now().UTC(),
+			},
+		},
+	}
+
+	rootKeys := []PublicKey{
+		{
+			Key: rootPub,
+			Metadata: KeyMetadata{
+				ID:        computeKeyID(rootPub),
+				CreatedAt: time.Now().UTC(),
+			},
+		},
+	}
+
+	// Generate two artifact keys
+	_, _, pubPEM1, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+	require.NoError(t, err)
+	pubKey1, err := ParseArtifactPubKey(pubPEM1)
+	require.NoError(t, err)
+
+	_, _, pubPEM2, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+	require.NoError(t, err)
+	pubKey2, err := ParseArtifactPubKey(pubPEM2)
+	require.NoError(t, err)
+
+	// Bundle both keys
+	bundle, sigData, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey1, pubKey2})
+	require.NoError(t, err)
+
+	sig, err := ParseSignature(sigData)
+	require.NoError(t, err)
+
+	// Create revocation list with first key revoked
+	revocationList := &RevocationList{
+		Revoked: map[KeyID]time.Time{
+			pubKey1.Metadata.ID: time.Now().UTC(),
+		},
+		LastUpdated: time.Now().UTC(),
+	}
+
+	// Validate - should only return second key
+	validKeys, err := ValidateArtifactKeys(rootKeys, bundle, *sig, revocationList)
+	require.NoError(t, err)
+	assert.Len(t, validKeys, 1)
+	assert.Equal(t, pubKey2.Metadata.ID, validKeys[0].Metadata.ID)
+}
+
+// TestValidateArtifactKeys_AllRevoked checks validation errors out when every
+// bundled key is revoked.
+func TestValidateArtifactKeys_AllRevoked(t *testing.T) {
+	rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+	require.NoError(t, err)
+
+	rootKey := &RootKey{
+		PrivateKey{
+			Key: rootPriv,
+			Metadata: KeyMetadata{
+				ID:        computeKeyID(rootPub),
+				CreatedAt: time.Now().UTC(),
+			},
+		},
+	}
+
+	rootKeys := []PublicKey{
+		{
+			Key: rootPub,
+			Metadata: KeyMetadata{
+				ID:        computeKeyID(rootPub),
+				CreatedAt: time.Now().UTC(),
+			},
+		},
+	}
+
+	_, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+	require.NoError(t, err)
+	pubKey, err := ParseArtifactPubKey(pubPEM)
+	require.NoError(t, err)
+
+	bundle, sigData, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey})
+	require.NoError(t, err)
+
+	sig, err := ParseSignature(sigData)
+	require.NoError(t, err)
+
+	// Revoke the key
+	revocationList := &RevocationList{
+		Revoked: map[KeyID]time.Time{
+			pubKey.Metadata.ID: time.Now().UTC(),
+		},
+		LastUpdated: time.Now().UTC(),
+	}
+
+	_, err = ValidateArtifactKeys(rootKeys, bundle, *sig, revocationList)
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "revoked")
+}
+
+// Test ValidateArtifact
+
+// TestValidateArtifact_Valid checks SignData output verifies against the
+// signing key's public half.
+func TestValidateArtifact_Valid(t *testing.T) {
+	// Create root key
+	rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+	require.NoError(t, err)
+
+	rootKey := &RootKey{
+		PrivateKey{
+			Key: rootPriv,
+			Metadata: KeyMetadata{
+				ID:        computeKeyID(rootPub),
+				CreatedAt: time.Now().UTC(),
+			},
+		},
+	}
+
+	// Generate artifact key
+	artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+	require.NoError(t, err)
+
+	// Sign some data
+	data := []byte("test artifact data")
+	sigData, err := SignData(*artifactKey, data)
+	require.NoError(t, err)
+
+	sig, err := ParseSignature(sigData)
+	require.NoError(t, err)
+
+	// Get public key for validation
+	artifactPubKey := PublicKey{
+		Key:      artifactKey.Key.Public().(ed25519.PublicKey),
+		Metadata: artifactKey.Metadata,
+	}
+
+	// Validate
+	err = ValidateArtifact([]PublicKey{artifactPubKey}, data, *sig)
+	assert.NoError(t, err)
+}
+
+func TestValidateArtifact_FutureTimestamp(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ artifactPubKey := PublicKey{
+ Key: pub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ },
+ }
+
+ sig := Signature{
+ Signature: make([]byte, 64),
+ Timestamp: time.Now().UTC().Add(10 * time.Minute),
+ KeyID: computeKeyID(pub),
+ Algorithm: "ed25519",
+ HashAlgo: "blake2s",
+ }
+
+ err = ValidateArtifact([]PublicKey{artifactPubKey}, []byte("data"), sig)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "in the future")
+}
+
+func TestValidateArtifact_TooOld(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ artifactPubKey := PublicKey{
+ Key: pub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ },
+ }
+
+ sig := Signature{
+ Signature: make([]byte, 64),
+ Timestamp: time.Now().UTC().Add(-20 * 365 * 24 * time.Hour),
+ KeyID: computeKeyID(pub),
+ Algorithm: "ed25519",
+ HashAlgo: "blake2s",
+ }
+
+ err = ValidateArtifact([]PublicKey{artifactPubKey}, []byte("data"), sig)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "too old")
+}
+
+func TestValidateArtifact_ExpiredKey(t *testing.T) {
+	rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+	require.NoError(t, err)
+
+	rootKey := &RootKey{
+		PrivateKey{
+			Key: rootPriv,
+			Metadata: KeyMetadata{
+				ID:        computeKeyID(rootPub),
+				CreatedAt: time.Now().UTC(),
+			},
+		},
+	}
+
+	// Generate artifact key with very short expiration
+	artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 1*time.Millisecond)
+	require.NoError(t, err)
+
+	// Wait for key to expire
+	time.Sleep(10 * time.Millisecond)
+
+	// Try to sign - SignData must refuse to use an expired key
+	data := []byte("test data")
+	sigData, err := SignData(*artifactKey, data)
+	require.Error(t, err) // Key is expired, so signing should fail
+	assert.Contains(t, err.Error(), "expired")
+	assert.Nil(t, sigData)
+}
+
+func TestValidateArtifact_WrongKey(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Generate two artifact keys
+ artifactKey1, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+
+ artifactKey2, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+
+ // Sign with key1
+ data := []byte("test data")
+ sigData, err := SignData(*artifactKey1, data)
+ require.NoError(t, err)
+
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+
+ // Try to validate with key2 only
+ artifactPubKey2 := PublicKey{
+ Key: artifactKey2.Key.Public().(ed25519.PublicKey),
+ Metadata: artifactKey2.Metadata,
+ }
+
+ err = ValidateArtifact([]PublicKey{artifactPubKey2}, data, *sig)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "no signing Key found")
+}
+
+func TestValidateArtifact_TamperedData(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+
+ // Sign original data
+ originalData := []byte("original data")
+ sigData, err := SignData(*artifactKey, originalData)
+ require.NoError(t, err)
+
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+
+ artifactPubKey := PublicKey{
+ Key: artifactKey.Key.Public().(ed25519.PublicKey),
+ Metadata: artifactKey.Metadata,
+ }
+
+ // Try to validate with tampered data
+ tamperedData := []byte("tampered data")
+ err = ValidateArtifact([]PublicKey{artifactPubKey}, tamperedData, *sig)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "verification failed")
+}
+
+func TestValidateArtifactKeys_TwoKeysOneExpired(t *testing.T) {
+ // Test ValidateArtifactKeys with a bundle containing two keys where one is expired
+ // Should return only the valid key
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Generate first key with very short expiration
+ _, _, expiredPubPEM, _, err := GenerateArtifactKey(rootKey, 1*time.Millisecond)
+ require.NoError(t, err)
+ expiredPubKey, err := ParseArtifactPubKey(expiredPubPEM)
+ require.NoError(t, err)
+
+ // Wait for first key to expire
+ time.Sleep(10 * time.Millisecond)
+
+ // Generate second key with normal expiration
+ _, _, validPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+ validPubKey, err := ParseArtifactPubKey(validPubPEM)
+ require.NoError(t, err)
+
+ // Bundle both keys together
+ bundle, sigData, err := BundleArtifactKeys(rootKey, []PublicKey{expiredPubKey, validPubKey})
+ require.NoError(t, err)
+
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+
+ // ValidateArtifactKeys should return only the valid key
+ validKeys, err := ValidateArtifactKeys(rootKeys, bundle, *sig, nil)
+ require.NoError(t, err)
+ assert.Len(t, validKeys, 1)
+ assert.Equal(t, validPubKey.Metadata.ID, validKeys[0].Metadata.ID)
+}
+
+func TestValidateArtifactKeys_TwoKeysBothExpired(t *testing.T) {
+	// NOTE(review): name says "BothExpired" but key1 has a 24h TTL and stays valid;
+	// the assertions below match current behavior (one surviving key) — rename or expire both?
+	rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+	require.NoError(t, err)
+
+	rootKey := &RootKey{
+		PrivateKey{
+			Key: rootPriv,
+			Metadata: KeyMetadata{
+				ID:        computeKeyID(rootPub),
+				CreatedAt: time.Now().UTC(),
+			},
+		},
+	}
+
+	rootKeys := []PublicKey{
+		{
+			Key: rootPub,
+			Metadata: KeyMetadata{
+				ID:        computeKeyID(rootPub),
+				CreatedAt: time.Now().UTC(),
+			},
+		},
+	}
+
+	// Generate first key with a normal 24h expiration (still valid)
+	_, _, pubPEM1, _, err := GenerateArtifactKey(rootKey, 24*time.Hour)
+	require.NoError(t, err)
+	pubKey1, err := ParseArtifactPubKey(pubPEM1)
+	require.NoError(t, err)
+
+	// Generate second key with very short expiration
+	_, _, pubPEM2, _, err := GenerateArtifactKey(rootKey, 1*time.Millisecond)
+	require.NoError(t, err)
+	pubKey2, err := ParseArtifactPubKey(pubPEM2)
+	require.NoError(t, err)
+
+	// Wait for the second key to expire
+	time.Sleep(10 * time.Millisecond)
+
+	bundle, sigData, err := BundleArtifactKeys(rootKey, []PublicKey{pubKey1, pubKey2})
+	require.NoError(t, err)
+
+	sig, err := ParseSignature(sigData)
+	require.NoError(t, err)
+
+	// Expired pubKey2 is filtered out; only the still-valid pubKey1 remains
+	keys, err := ValidateArtifactKeys(rootKeys, bundle, *sig, nil)
+	assert.NoError(t, err)
+	assert.Len(t, keys, 1)
+}
+
+// Test SignData
+
+func TestSignData_Valid(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+
+ data := []byte("test data to sign")
+ sigData, err := SignData(*artifactKey, data)
+ require.NoError(t, err)
+ assert.NotEmpty(t, sigData)
+
+ // Verify signature can be parsed
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+ assert.NotEmpty(t, sig.Signature)
+ assert.Equal(t, "ed25519", sig.Algorithm)
+ assert.Equal(t, "blake2s", sig.HashAlgo)
+}
+
+func TestSignData_EmptyData(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+
+ _, err = SignData(*artifactKey, []byte{})
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "must be positive")
+}
+
+func TestSignData_ExpiredKey(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Generate key with very short expiration
+ artifactKey, _, _, _, err := GenerateArtifactKey(rootKey, 1*time.Millisecond)
+ require.NoError(t, err)
+
+ // Wait for expiration
+ time.Sleep(10 * time.Millisecond)
+
+ // Try to sign with expired key
+ _, err = SignData(*artifactKey, []byte("data"))
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "expired")
+}
+
+// Integration test
+
+func TestArtifact_FullWorkflow(t *testing.T) {
+ // Step 1: Create root key
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := &RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Step 2: Generate artifact key
+ artifactKey, _, pubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+
+ // Step 3: Create and validate key bundle
+ artifactPubKey, err := ParseArtifactPubKey(pubPEM)
+ require.NoError(t, err)
+
+ bundle, bundleSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey})
+ require.NoError(t, err)
+
+ sig, err := ParseSignature(bundleSig)
+ require.NoError(t, err)
+
+ validKeys, err := ValidateArtifactKeys(rootKeys, bundle, *sig, nil)
+ require.NoError(t, err)
+ assert.Len(t, validKeys, 1)
+
+ // Step 4: Sign artifact data
+ artifactData := []byte("This is my artifact data that needs to be signed")
+ artifactSig, err := SignData(*artifactKey, artifactData)
+ require.NoError(t, err)
+
+ // Step 5: Validate artifact
+ parsedSig, err := ParseSignature(artifactSig)
+ require.NoError(t, err)
+
+ err = ValidateArtifact(validKeys, artifactData, *parsedSig)
+ assert.NoError(t, err)
+}
+
+// encodePrivateKey PEM-encodes JSON-serialized key bytes under typeTag (test helper).
+func encodePrivateKey(jsonData []byte, typeTag string) []byte {
+	return pem.EncodeToMemory(&pem.Block{
+		Type:  typeTag,
+		Bytes: jsonData,
+	})
+}
diff --git a/client/internal/updatemanager/reposign/certs/root-pub.pem b/client/internal/updatemanager/reposign/certs/root-pub.pem
new file mode 100644
index 000000000..e7c2fd2c0
--- /dev/null
+++ b/client/internal/updatemanager/reposign/certs/root-pub.pem
@@ -0,0 +1,6 @@
+-----BEGIN ROOT PUBLIC KEY-----
+eyJLZXkiOiJoaGIxdGRDSEZNMFBuQWp1b2w2cXJ1QXRFbWFFSlg1QjFsZUNxWmpn
+V1pvPSIsIk1ldGFkYXRhIjp7ImlkIjoiOWE0OTg2NmI2MzE2MjNiNCIsImNyZWF0
+ZWRfYXQiOiIyMDI1LTExLTI0VDE3OjE1OjI4LjYyNzE3MzE3MVoiLCJleHBpcmVz
+X2F0IjoiMjAzNS0xMS0yMlQxNzoxNToyOC42MjcxNzMxNzFaIn19
+-----END ROOT PUBLIC KEY-----
diff --git a/client/internal/updatemanager/reposign/certsdev/root-pub.pem b/client/internal/updatemanager/reposign/certsdev/root-pub.pem
new file mode 100644
index 000000000..f7145477b
--- /dev/null
+++ b/client/internal/updatemanager/reposign/certsdev/root-pub.pem
@@ -0,0 +1,6 @@
+-----BEGIN ROOT PUBLIC KEY-----
+eyJLZXkiOiJyTDByVTN2MEFOZUNmbDZraitiUUd3TE1waU5CaUJLdVBWSnZtQzgr
+ZS84PSIsIk1ldGFkYXRhIjp7ImlkIjoiMTBkNjQyZTY2N2FmMDNkNCIsImNyZWF0
+ZWRfYXQiOiIyMDI1LTExLTIwVDE3OjI5OjI5LjE4MDk0NjMxNloiLCJleHBpcmVz
+X2F0IjoiMjAyNi0xMS0yMFQxNzoyOToyOS4xODA5NDYzMTZaIn19
+-----END ROOT PUBLIC KEY-----
diff --git a/client/internal/updatemanager/reposign/doc.go b/client/internal/updatemanager/reposign/doc.go
new file mode 100644
index 000000000..660b9d11d
--- /dev/null
+++ b/client/internal/updatemanager/reposign/doc.go
@@ -0,0 +1,174 @@
+// Package reposign implements a cryptographic signing and verification system
+// for NetBird software update artifacts. It provides a hierarchical key
+// management system with support for key rotation, revocation, and secure
+// artifact distribution.
+//
+// # Architecture
+//
+// The package uses a two-tier key hierarchy:
+//
+// - Root Keys: Long-lived keys that sign artifact keys. These are embedded
+// in the client binary and establish the root of trust. Root keys should
+// be kept offline and highly secured.
+//
+// - Artifact Keys: Short-lived keys that sign release artifacts (binaries,
+// packages, etc.). These are rotated regularly and can be revoked if
+// compromised. Artifact keys are signed by root keys and distributed via
+// a public repository.
+//
+// This separation allows for operational flexibility: artifact keys can be
+// rotated frequently without requiring client updates, while root keys remain
+// stable and embedded in the software.
+//
+// # Cryptographic Primitives
+//
+// The package uses strong, modern cryptographic algorithms:
+// - Ed25519: Fast, secure digital signatures (no timing attacks)
+// - BLAKE2s-256: Fast cryptographic hash for artifacts
+// - SHA-256: Key ID generation
+// - JSON: Structured key and signature serialization
+// - PEM: Standard key encoding format
+//
+// # Security Features
+//
+// Timestamp Binding:
+// - All signatures include cryptographically-bound timestamps
+// - Prevents replay attacks and enforces signature freshness
+// - Clock skew tolerance: 5 minutes
+//
+// Key Expiration:
+// - All keys have expiration times
+// - Expired keys are automatically rejected
+// - Signing with an expired key fails immediately
+//
+// Key Revocation:
+// - Compromised keys can be revoked via a signed revocation list
+// - Revocation list is checked during artifact validation
+// - Revoked keys are filtered out before artifact verification
+//
+// # File Structure
+//
+// The package expects the following file layout in the key repository:
+//
+// signrepo/
+// artifact-key-pub.pem # Bundle of artifact public keys
+// artifact-key-pub.pem.sig # Root signature of the bundle
+// revocation-list.json # List of revoked key IDs
+// revocation-list.json.sig # Root signature of revocation list
+//
+// And in the artifacts repository:
+//
+// releases/
+// v0.28.0/
+// netbird-linux-amd64
+// netbird-linux-amd64.sig # Artifact signature
+// netbird-darwin-amd64
+// netbird-darwin-amd64.sig
+// ...
+//
+// # Embedded Root Keys
+//
+// Root public keys are embedded in the client binary at compile time:
+// - Production keys: certs/ directory
+// - Development keys: certsdev/ directory
+//
+// The build tag determines which keys are embedded:
+// - Production builds: //go:build !devartifactsign
+// - Development builds: //go:build devartifactsign
+//
+// This ensures that development artifacts cannot be verified using production
+// keys and vice versa.
+//
+// # Key Rotation Strategies
+//
+// Root Key Rotation:
+//
+// Root keys can be rotated without breaking existing clients by leveraging
+// the multi-key verification system. The loadEmbeddedPublicKeys function
+// reads ALL files from the certs/ directory and accepts signatures from ANY
+// of the embedded root keys.
+//
+// To rotate root keys:
+//
+// 1. Generate a new root key pair:
+// newRootKey, privPEM, pubPEM, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
+//
+// 2. Add the new public key to the certs/ directory as a new file:
+// certs/
+// root-pub-2024.pem # Old key (keep this!)
+// root-pub-2025.pem # New key (add this)
+//
+// 3. Build new client versions with both keys embedded. The verification
+// will accept signatures from either key.
+//
+// 4. Start signing new artifact keys with the new root key. Old clients
+// with only the old root key will reject these, but new clients with
+// both keys will accept them.
+//
+// Each file in certs/ can contain a single key or a bundle of keys (multiple
+// PEM blocks). The system will parse all keys from all files and use them
+// for verification. This provides maximum flexibility for key management.
+//
+// Important: Never remove all old root keys at once. Always maintain at least
+// one overlapping key between releases to ensure smooth transitions.
+//
+// Artifact Key Rotation:
+//
+// Artifact keys should be rotated regularly (e.g., every 90 days) using the
+// bundling mechanism. The BundleArtifactKeys function allows multiple artifact
+// keys to be bundled together in a single signed package, and ValidateArtifact
+// will accept signatures from ANY key in the bundle.
+//
+// To rotate artifact keys smoothly:
+//
+// 1. Generate a new artifact key while keeping the old one:
+// newKey, newPrivPEM, newPubPEM, newSig, err := GenerateArtifactKey(rootKey, 90 * 24 * time.Hour)
+// // Keep oldPubPEM and oldKey available
+//
+// 2. Create a bundle containing both old and new public keys
+//
+// 3. Upload the bundle and its signature to the key repository:
+// signrepo/artifact-key-pub.pem # Contains both keys
+// signrepo/artifact-key-pub.pem.sig # Root signature
+//
+// 4. Start signing new releases with the NEW key, but keep the bundle
+// unchanged. Clients will download the bundle (containing both keys)
+// and accept signatures from either key.
+//
+// Key bundle validation workflow:
+// 1. Client downloads artifact-key-pub.pem and artifact-key-pub.pem.sig
+// 2. ValidateArtifactKeys verifies the bundle signature with ANY embedded root key
+// 3. ValidateArtifactKeys parses all public keys from the bundle
+// 4. ValidateArtifactKeys filters out expired or revoked keys
+// 5. When verifying an artifact, ValidateArtifact tries each key until one succeeds
+//
+// This multi-key acceptance model enables overlapping validity periods and
+// smooth transitions without client update requirements.
+//
+// # Best Practices
+//
+// Root Key Management:
+// - Generate root keys offline on an air-gapped machine
+// - Store root private keys in hardware security modules (HSM) if possible
+// - Use separate root keys for production and development
+// - Rotate root keys infrequently (e.g., every 5-10 years)
+// - Plan for root key rotation: embed multiple root public keys
+//
+// Artifact Key Management:
+// - Rotate artifact keys regularly (e.g., every 90 days)
+// - Use separate artifact keys for different release channels if needed
+// - Revoke keys immediately upon suspected compromise
+// - Bundle multiple artifact keys to enable smooth rotation
+//
+// Signing Process:
+// - Sign artifacts in a secure CI/CD environment
+// - Never commit private keys to version control
+// - Use environment variables or secret management for keys
+// - Verify signatures immediately after signing
+//
+// Distribution:
+// - Serve keys and revocation lists from a reliable CDN
+// - Use HTTPS for all key and artifact downloads
+// - Monitor download failures and signature verification failures
+// - Keep revocation list up to date
+package reposign
diff --git a/client/internal/updatemanager/reposign/embed_dev.go b/client/internal/updatemanager/reposign/embed_dev.go
new file mode 100644
index 000000000..ef8f77373
--- /dev/null
+++ b/client/internal/updatemanager/reposign/embed_dev.go
@@ -0,0 +1,10 @@
+//go:build devartifactsign
+
+package reposign
+
+import "embed"
+
+//go:embed certsdev
+var embeddedCerts embed.FS
+
+const embeddedCertsDir = "certsdev"
diff --git a/client/internal/updatemanager/reposign/embed_prod.go b/client/internal/updatemanager/reposign/embed_prod.go
new file mode 100644
index 000000000..91530e5f4
--- /dev/null
+++ b/client/internal/updatemanager/reposign/embed_prod.go
@@ -0,0 +1,10 @@
+//go:build !devartifactsign
+
+package reposign
+
+import "embed"
+
+//go:embed certs
+var embeddedCerts embed.FS
+
+const embeddedCertsDir = "certs"
diff --git a/client/internal/updatemanager/reposign/key.go b/client/internal/updatemanager/reposign/key.go
new file mode 100644
index 000000000..bedfef70d
--- /dev/null
+++ b/client/internal/updatemanager/reposign/key.go
@@ -0,0 +1,171 @@
+package reposign
+
+import (
+ "crypto/ed25519"
+ "crypto/sha256"
+ "encoding/hex"
+ "encoding/json"
+ "encoding/pem"
+ "errors"
+ "fmt"
+ "time"
+)
+
+const (
+ maxClockSkew = 5 * time.Minute
+)
+
+// KeyID is a unique identifier for a Key (first 8 bytes of SHA-256 of public Key)
+type KeyID [8]byte
+
+// computeKeyID generates a unique ID from a public Key
+func computeKeyID(pub ed25519.PublicKey) KeyID {
+ h := sha256.Sum256(pub)
+ var id KeyID
+ copy(id[:], h[:8])
+ return id
+}
+
+// MarshalJSON implements json.Marshaler for KeyID
+func (k KeyID) MarshalJSON() ([]byte, error) {
+ return json.Marshal(k.String())
+}
+
+// UnmarshalJSON implements json.Unmarshaler for KeyID
+func (k *KeyID) UnmarshalJSON(data []byte) error {
+ var s string
+ if err := json.Unmarshal(data, &s); err != nil {
+ return err
+ }
+
+ parsed, err := ParseKeyID(s)
+ if err != nil {
+ return err
+ }
+
+ *k = parsed
+ return nil
+}
+
+// ParseKeyID parses a hex string (16 hex chars = 8 bytes) into a KeyID.
+func ParseKeyID(s string) (KeyID, error) {
+ var id KeyID
+ if len(s) != 16 {
+ return id, fmt.Errorf("invalid KeyID length: got %d, want 16 hex chars (8 bytes)", len(s))
+ }
+
+ b, err := hex.DecodeString(s)
+ if err != nil {
+ return id, fmt.Errorf("failed to decode KeyID: %w", err)
+ }
+
+ copy(id[:], b)
+ return id, nil
+}
+
+func (k KeyID) String() string {
+ return fmt.Sprintf("%x", k[:])
+}
+
+// KeyMetadata contains versioning and lifecycle information for a Key
+type KeyMetadata struct {
+ ID KeyID `json:"id"`
+ CreatedAt time.Time `json:"created_at"`
+ ExpiresAt time.Time `json:"expires_at,omitempty"` // Optional expiration
+}
+
+// PublicKey wraps a public Key with its Metadata
+type PublicKey struct {
+ Key ed25519.PublicKey
+ Metadata KeyMetadata
+}
+
+func parsePublicKeyBundle(bundle []byte, typeTag string) ([]PublicKey, error) {
+ var keys []PublicKey
+ for len(bundle) > 0 {
+ keyInfo, rest, err := parsePublicKey(bundle, typeTag)
+ if err != nil {
+ return nil, err
+ }
+ keys = append(keys, keyInfo)
+ bundle = rest
+ }
+ if len(keys) == 0 {
+ return nil, errors.New("no keys found in bundle")
+ }
+ return keys, nil
+}
+
+func parsePublicKey(data []byte, typeTag string) (PublicKey, []byte, error) {
+ b, rest := pem.Decode(data)
+ if b == nil {
+ return PublicKey{}, nil, errors.New("failed to decode PEM data")
+ }
+ if b.Type != typeTag {
+ return PublicKey{}, nil, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag)
+ }
+
+ // Unmarshal JSON-embedded format
+ var pub PublicKey
+ if err := json.Unmarshal(b.Bytes, &pub); err != nil {
+ return PublicKey{}, nil, fmt.Errorf("failed to unmarshal public key: %w", err)
+ }
+
+ // Validate key length
+ if len(pub.Key) != ed25519.PublicKeySize {
+ return PublicKey{}, nil, fmt.Errorf("incorrect Ed25519 public key size: expected %d, got %d",
+ ed25519.PublicKeySize, len(pub.Key))
+ }
+
+ // Always recompute ID to ensure integrity
+ pub.Metadata.ID = computeKeyID(pub.Key)
+
+ return pub, rest, nil
+}
+
+type PrivateKey struct {
+ Key ed25519.PrivateKey
+ Metadata KeyMetadata
+}
+
+func parsePrivateKey(data []byte, typeTag string) (PrivateKey, error) {
+ b, rest := pem.Decode(data)
+ if b == nil {
+ return PrivateKey{}, errors.New("failed to decode PEM data")
+ }
+ if len(rest) > 0 {
+ return PrivateKey{}, errors.New("trailing PEM data")
+ }
+ if b.Type != typeTag {
+ return PrivateKey{}, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag)
+ }
+
+ // Unmarshal JSON-embedded format
+ var pk PrivateKey
+ if err := json.Unmarshal(b.Bytes, &pk); err != nil {
+ return PrivateKey{}, fmt.Errorf("failed to unmarshal private key: %w", err)
+ }
+
+ // Validate key length
+ if len(pk.Key) != ed25519.PrivateKeySize {
+ return PrivateKey{}, fmt.Errorf("incorrect Ed25519 private key size: expected %d, got %d",
+ ed25519.PrivateKeySize, len(pk.Key))
+ }
+
+ return pk, nil
+}
+
+func verifyAny(publicRootKeys []PublicKey, msg, sig []byte) bool {
+ // Verify with root keys
+ var rootKeys []ed25519.PublicKey
+ for _, r := range publicRootKeys {
+ rootKeys = append(rootKeys, r.Key)
+ }
+
+ for _, k := range rootKeys {
+ if ed25519.Verify(k, msg, sig) {
+ return true
+ }
+ }
+ return false
+}
diff --git a/client/internal/updatemanager/reposign/key_test.go b/client/internal/updatemanager/reposign/key_test.go
new file mode 100644
index 000000000..f8e1676fb
--- /dev/null
+++ b/client/internal/updatemanager/reposign/key_test.go
@@ -0,0 +1,636 @@
+package reposign
+
+import (
+ "crypto/ed25519"
+ "crypto/rand"
+ "crypto/sha256"
+ "encoding/json"
+ "encoding/pem"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// Test KeyID functions
+
+func TestComputeKeyID(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ keyID := computeKeyID(pub)
+
+ // Verify it's the first 8 bytes of SHA-256
+ h := sha256.Sum256(pub)
+ expectedID := KeyID{}
+ copy(expectedID[:], h[:8])
+
+ assert.Equal(t, expectedID, keyID)
+}
+
+func TestComputeKeyID_Deterministic(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ // Computing KeyID multiple times should give the same result
+ keyID1 := computeKeyID(pub)
+ keyID2 := computeKeyID(pub)
+
+ assert.Equal(t, keyID1, keyID2)
+}
+
+func TestComputeKeyID_DifferentKeys(t *testing.T) {
+ pub1, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ pub2, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ keyID1 := computeKeyID(pub1)
+ keyID2 := computeKeyID(pub2)
+
+ // Different keys should produce different IDs
+ assert.NotEqual(t, keyID1, keyID2)
+}
+
+func TestParseKeyID_Valid(t *testing.T) {
+ hexStr := "0123456789abcdef"
+
+ keyID, err := ParseKeyID(hexStr)
+ require.NoError(t, err)
+
+ expected := KeyID{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}
+ assert.Equal(t, expected, keyID)
+}
+
+func TestParseKeyID_InvalidLength(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ }{
+ {"too short", "01234567"},
+ {"too long", "0123456789abcdef00"},
+ {"empty", ""},
+ {"odd length", "0123456789abcde"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ _, err := ParseKeyID(tt.input)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "invalid KeyID length")
+ })
+ }
+}
+
+func TestParseKeyID_InvalidHex(t *testing.T) {
+ invalidHex := "0123456789abcxyz" // 'xyz' are not valid hex
+
+ _, err := ParseKeyID(invalidHex)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to decode KeyID")
+}
+
+func TestKeyID_String(t *testing.T) {
+ keyID := KeyID{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}
+
+ str := keyID.String()
+ assert.Equal(t, "0123456789abcdef", str)
+}
+
+func TestKeyID_RoundTrip(t *testing.T) {
+ original := "fedcba9876543210"
+
+ keyID, err := ParseKeyID(original)
+ require.NoError(t, err)
+
+ result := keyID.String()
+ assert.Equal(t, original, result)
+}
+
+func TestKeyID_ZeroValue(t *testing.T) {
+ keyID := KeyID{}
+ str := keyID.String()
+ assert.Equal(t, "0000000000000000", str)
+}
+
+// Test KeyMetadata
+
+func TestKeyMetadata_JSONMarshaling(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ metadata := KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC),
+ ExpiresAt: time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC),
+ }
+
+ jsonData, err := json.Marshal(metadata)
+ require.NoError(t, err)
+
+ var decoded KeyMetadata
+ err = json.Unmarshal(jsonData, &decoded)
+ require.NoError(t, err)
+
+ assert.Equal(t, metadata.ID, decoded.ID)
+ assert.Equal(t, metadata.CreatedAt.Unix(), decoded.CreatedAt.Unix())
+ assert.Equal(t, metadata.ExpiresAt.Unix(), decoded.ExpiresAt.Unix())
+}
+
+func TestKeyMetadata_NoExpiration(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ metadata := KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC),
+ ExpiresAt: time.Time{}, // Zero value = no expiration
+ }
+
+ jsonData, err := json.Marshal(metadata)
+ require.NoError(t, err)
+
+ var decoded KeyMetadata
+ err = json.Unmarshal(jsonData, &decoded)
+ require.NoError(t, err)
+
+ assert.True(t, decoded.ExpiresAt.IsZero())
+}
+
+// Test PublicKey
+
+func TestPublicKey_JSONMarshaling(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ pubKey := PublicKey{
+ Key: pub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(),
+ },
+ }
+
+ jsonData, err := json.Marshal(pubKey)
+ require.NoError(t, err)
+
+ var decoded PublicKey
+ err = json.Unmarshal(jsonData, &decoded)
+ require.NoError(t, err)
+
+ assert.Equal(t, pubKey.Key, decoded.Key)
+ assert.Equal(t, pubKey.Metadata.ID, decoded.Metadata.ID)
+}
+
+// Test parsePublicKey
+
+func TestParsePublicKey_Valid(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ metadata := KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(),
+ }
+
+ pubKey := PublicKey{
+ Key: pub,
+ Metadata: metadata,
+ }
+
+ // Marshal to JSON
+ jsonData, err := json.Marshal(pubKey)
+ require.NoError(t, err)
+
+ // Encode to PEM
+ pemData := pem.EncodeToMemory(&pem.Block{
+ Type: tagRootPublic,
+ Bytes: jsonData,
+ })
+
+ // Parse it back
+ parsed, rest, err := parsePublicKey(pemData, tagRootPublic)
+ require.NoError(t, err)
+ assert.Empty(t, rest)
+ assert.Equal(t, pub, parsed.Key)
+ assert.Equal(t, metadata.ID, parsed.Metadata.ID)
+}
+
+func TestParsePublicKey_InvalidPEM(t *testing.T) {
+ invalidPEM := []byte("not a PEM")
+
+ _, _, err := parsePublicKey(invalidPEM, tagRootPublic)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to decode PEM")
+}
+
+func TestParsePublicKey_WrongType(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ pubKey := PublicKey{
+ Key: pub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ },
+ }
+
+ jsonData, err := json.Marshal(pubKey)
+ require.NoError(t, err)
+
+ // Encode with wrong type
+ pemData := pem.EncodeToMemory(&pem.Block{
+ Type: "WRONG TYPE",
+ Bytes: jsonData,
+ })
+
+ _, _, err = parsePublicKey(pemData, tagRootPublic)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "PEM type")
+}
+
+func TestParsePublicKey_InvalidJSON(t *testing.T) {
+ pemData := pem.EncodeToMemory(&pem.Block{
+ Type: tagRootPublic,
+ Bytes: []byte("invalid json"),
+ })
+
+ _, _, err := parsePublicKey(pemData, tagRootPublic)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to unmarshal")
+}
+
+func TestParsePublicKey_InvalidKeySize(t *testing.T) {
+ // Create a public key with wrong size
+ pubKey := PublicKey{
+ Key: []byte{0x01, 0x02, 0x03}, // Too short
+ Metadata: KeyMetadata{
+ ID: KeyID{},
+ CreatedAt: time.Now().UTC(),
+ },
+ }
+
+ jsonData, err := json.Marshal(pubKey)
+ require.NoError(t, err)
+
+ pemData := pem.EncodeToMemory(&pem.Block{
+ Type: tagRootPublic,
+ Bytes: jsonData,
+ })
+
+ _, _, err = parsePublicKey(pemData, tagRootPublic)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "incorrect Ed25519 public key size")
+}
+
+func TestParsePublicKey_IDRecomputation(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ // Create a public key with WRONG ID
+ wrongID := KeyID{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
+ pubKey := PublicKey{
+ Key: pub,
+ Metadata: KeyMetadata{
+ ID: wrongID,
+ CreatedAt: time.Now().UTC(),
+ },
+ }
+
+ jsonData, err := json.Marshal(pubKey)
+ require.NoError(t, err)
+
+ pemData := pem.EncodeToMemory(&pem.Block{
+ Type: tagRootPublic,
+ Bytes: jsonData,
+ })
+
+ // Parse should recompute the correct ID
+ parsed, _, err := parsePublicKey(pemData, tagRootPublic)
+ require.NoError(t, err)
+
+ correctID := computeKeyID(pub)
+ assert.Equal(t, correctID, parsed.Metadata.ID)
+ assert.NotEqual(t, wrongID, parsed.Metadata.ID)
+}
+
+// Test parsePublicKeyBundle
+
+func TestParsePublicKeyBundle_Single(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ pubKey := PublicKey{
+ Key: pub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ },
+ }
+
+ jsonData, err := json.Marshal(pubKey)
+ require.NoError(t, err)
+
+ pemData := pem.EncodeToMemory(&pem.Block{
+ Type: tagRootPublic,
+ Bytes: jsonData,
+ })
+
+ keys, err := parsePublicKeyBundle(pemData, tagRootPublic)
+ require.NoError(t, err)
+ assert.Len(t, keys, 1)
+ assert.Equal(t, pub, keys[0].Key)
+}
+
+func TestParsePublicKeyBundle_Multiple(t *testing.T) {
+ var bundle []byte
+
+ // Create 3 keys
+ for i := 0; i < 3; i++ {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ pubKey := PublicKey{
+ Key: pub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ },
+ }
+
+ jsonData, err := json.Marshal(pubKey)
+ require.NoError(t, err)
+
+ pemData := pem.EncodeToMemory(&pem.Block{
+ Type: tagRootPublic,
+ Bytes: jsonData,
+ })
+
+ bundle = append(bundle, pemData...)
+ }
+
+ keys, err := parsePublicKeyBundle(bundle, tagRootPublic)
+ require.NoError(t, err)
+ assert.Len(t, keys, 3)
+}
+
+// An empty bundle must be rejected rather than returning zero keys.
+func TestParsePublicKeyBundle_Empty(t *testing.T) {
+	_, err := parsePublicKeyBundle([]byte{}, tagRootPublic)
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "no keys found")
+}
+
+// Non-PEM garbage must fail parsing, not be silently skipped.
+func TestParsePublicKeyBundle_Invalid(t *testing.T) {
+	_, err := parsePublicKeyBundle([]byte("invalid data"), tagRootPublic)
+	assert.Error(t, err)
+}
+
+// Test PrivateKey
+
+func TestPrivateKey_JSONMarshaling(t *testing.T) {
+ pub, priv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ privKey := PrivateKey{
+ Key: priv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ },
+ }
+
+ jsonData, err := json.Marshal(privKey)
+ require.NoError(t, err)
+
+ var decoded PrivateKey
+ err = json.Unmarshal(jsonData, &decoded)
+ require.NoError(t, err)
+
+ assert.Equal(t, privKey.Key, decoded.Key)
+ assert.Equal(t, privKey.Metadata.ID, decoded.Metadata.ID)
+}
+
+// Test parsePrivateKey
+
+func TestParsePrivateKey_Valid(t *testing.T) {
+ pub, priv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ privKey := PrivateKey{
+ Key: priv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ },
+ }
+
+ jsonData, err := json.Marshal(privKey)
+ require.NoError(t, err)
+
+ pemData := pem.EncodeToMemory(&pem.Block{
+ Type: tagRootPrivate,
+ Bytes: jsonData,
+ })
+
+ parsed, err := parsePrivateKey(pemData, tagRootPrivate)
+ require.NoError(t, err)
+ assert.Equal(t, priv, parsed.Key)
+}
+
+func TestParsePrivateKey_InvalidPEM(t *testing.T) {
+ _, err := parsePrivateKey([]byte("not a PEM"), tagRootPrivate)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to decode PEM")
+}
+
+func TestParsePrivateKey_TrailingData(t *testing.T) {
+ pub, priv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ privKey := PrivateKey{
+ Key: priv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ },
+ }
+
+ jsonData, err := json.Marshal(privKey)
+ require.NoError(t, err)
+
+ pemData := pem.EncodeToMemory(&pem.Block{
+ Type: tagRootPrivate,
+ Bytes: jsonData,
+ })
+
+ // Add trailing data
+ pemData = append(pemData, []byte("extra data")...)
+
+ _, err = parsePrivateKey(pemData, tagRootPrivate)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "trailing PEM data")
+}
+
+func TestParsePrivateKey_WrongType(t *testing.T) {
+ pub, priv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ privKey := PrivateKey{
+ Key: priv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ },
+ }
+
+ jsonData, err := json.Marshal(privKey)
+ require.NoError(t, err)
+
+ pemData := pem.EncodeToMemory(&pem.Block{
+ Type: "WRONG TYPE",
+ Bytes: jsonData,
+ })
+
+ _, err = parsePrivateKey(pemData, tagRootPrivate)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "PEM type")
+}
+
+func TestParsePrivateKey_InvalidKeySize(t *testing.T) {
+ privKey := PrivateKey{
+ Key: []byte{0x01, 0x02, 0x03}, // Too short
+ Metadata: KeyMetadata{
+ ID: KeyID{},
+ CreatedAt: time.Now().UTC(),
+ },
+ }
+
+ jsonData, err := json.Marshal(privKey)
+ require.NoError(t, err)
+
+ pemData := pem.EncodeToMemory(&pem.Block{
+ Type: tagRootPrivate,
+ Bytes: jsonData,
+ })
+
+ _, err = parsePrivateKey(pemData, tagRootPrivate)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "incorrect Ed25519 private key size")
+}
+
+// Test verifyAny
+
+func TestVerifyAny_ValidSignature(t *testing.T) {
+ pub, priv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ message := []byte("test message")
+ signature := ed25519.Sign(priv, message)
+
+ rootKeys := []PublicKey{
+ {
+ Key: pub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ result := verifyAny(rootKeys, message, signature)
+ assert.True(t, result)
+}
+
+func TestVerifyAny_InvalidSignature(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ message := []byte("test message")
+ invalidSignature := make([]byte, ed25519.SignatureSize)
+
+ rootKeys := []PublicKey{
+ {
+ Key: pub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ result := verifyAny(rootKeys, message, invalidSignature)
+ assert.False(t, result)
+}
+
+func TestVerifyAny_MultipleKeys(t *testing.T) {
+ // Create 3 key pairs
+ pub1, priv1, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ pub2, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ pub3, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ message := []byte("test message")
+ signature := ed25519.Sign(priv1, message)
+
+ rootKeys := []PublicKey{
+ {Key: pub2, Metadata: KeyMetadata{ID: computeKeyID(pub2)}},
+ {Key: pub1, Metadata: KeyMetadata{ID: computeKeyID(pub1)}}, // Correct key in middle
+ {Key: pub3, Metadata: KeyMetadata{ID: computeKeyID(pub3)}},
+ }
+
+ result := verifyAny(rootKeys, message, signature)
+ assert.True(t, result)
+}
+
+func TestVerifyAny_NoMatchingKey(t *testing.T) {
+ _, priv1, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ pub2, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ message := []byte("test message")
+ signature := ed25519.Sign(priv1, message)
+
+ // Only include pub2, not pub1
+ rootKeys := []PublicKey{
+ {Key: pub2, Metadata: KeyMetadata{ID: computeKeyID(pub2)}},
+ }
+
+ result := verifyAny(rootKeys, message, signature)
+ assert.False(t, result)
+}
+
+// With no trusted keys, verifyAny must never report success.
+func TestVerifyAny_EmptyKeys(t *testing.T) {
+	message := []byte("test message")
+	signature := make([]byte, ed25519.SignatureSize)
+
+	result := verifyAny([]PublicKey{}, message, signature)
+	assert.False(t, result)
+}
+
+// A signature over one message must not verify against different bytes.
+func TestVerifyAny_TamperedMessage(t *testing.T) {
+	pub, priv, err := ed25519.GenerateKey(rand.Reader)
+	require.NoError(t, err)
+
+	message := []byte("test message")
+	signature := ed25519.Sign(priv, message)
+
+	rootKeys := []PublicKey{
+		{Key: pub, Metadata: KeyMetadata{ID: computeKeyID(pub)}},
+	}
+
+	// Verify with different message
+	tamperedMessage := []byte("different message")
+	result := verifyAny(rootKeys, tamperedMessage, signature)
+	assert.False(t, result)
+}
diff --git a/client/internal/updatemanager/reposign/revocation.go b/client/internal/updatemanager/reposign/revocation.go
new file mode 100644
index 000000000..e679e212f
--- /dev/null
+++ b/client/internal/updatemanager/reposign/revocation.go
@@ -0,0 +1,229 @@
+package reposign
+
+import (
+ "crypto/ed25519"
+ "encoding/binary"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+)
+
+// Freshness bounds for revocation lists and their signatures.
+const (
+	// maxRevocationSignatureAge is the oldest a revocation-list signature may be
+	// before it is rejected (~10 years); it also caps how far in the future a
+	// list's ExpiresAt may lie (see ValidateRevocationList).
+	maxRevocationSignatureAge       = 10 * 365 * 24 * time.Hour
+	// defaultRevocationListExpiration is a conventional validity window (~1 year)
+	// callers can pass when creating or extending a list.
+	defaultRevocationListExpiration = 365 * 24 * time.Hour
+)
+
+// RevocationList records which signing keys have been revoked and when,
+// together with freshness metadata used to reject stale or replayed lists.
+type RevocationList struct {
+	Revoked     map[KeyID]time.Time `json:"revoked"`      // KeyID -> revocation time
+	LastUpdated time.Time           `json:"last_updated"` // When the list was last modified
+	ExpiresAt   time.Time           `json:"expires_at"`   // When the list expires
+}
+
+// MarshalJSON implements json.Marshaler. KeyID map keys are serialized via
+// their String() form because JSON object keys must be strings.
+//
+// Go marshals map entries in sorted key order, so the output bytes are
+// deterministic — important because these are the bytes that get signed.
+func (rl RevocationList) MarshalJSON() ([]byte, error) {
+	// Convert map[KeyID]time.Time to map[string]time.Time
+	strMap := make(map[string]time.Time, len(rl.Revoked))
+	for k, v := range rl.Revoked {
+		strMap[k.String()] = v
+	}
+
+	return json.Marshal(map[string]interface{}{
+		"revoked":      strMap,
+		"last_updated": rl.LastUpdated,
+		"expires_at":   rl.ExpiresAt,
+	})
+}
+
+// UnmarshalJSON implements json.Unmarshaler, converting the string-keyed JSON
+// "revoked" object back into a map keyed by KeyID. It fails if any map key
+// does not parse as a KeyID.
+func (rl *RevocationList) UnmarshalJSON(data []byte) error {
+	var temp struct {
+		Revoked     map[string]time.Time `json:"revoked"`
+		LastUpdated time.Time            `json:"last_updated"`
+		ExpiresAt   time.Time            `json:"expires_at"`
+		// NOTE(review): Version is decoded here but never copied onto rl or read
+		// anywhere in this file — confirm whether list versioning is still
+		// planned or whether this field can be removed.
+		Version     int                  `json:"version"`
+	}
+
+	if err := json.Unmarshal(data, &temp); err != nil {
+		return err
+	}
+
+	// Convert map[string]time.Time back to map[KeyID]time.Time
+	rl.Revoked = make(map[KeyID]time.Time, len(temp.Revoked))
+	for k, v := range temp.Revoked {
+		kid, err := ParseKeyID(k)
+		if err != nil {
+			return fmt.Errorf("failed to parse KeyID %q: %w", k, err)
+		}
+		rl.Revoked[kid] = v
+	}
+
+	rl.LastUpdated = temp.LastUpdated
+	rl.ExpiresAt = temp.ExpiresAt
+
+	return nil
+}
+
+// ParseRevocationList decodes a JSON-encoded revocation list. Lists missing
+// the last_updated or expires_at timestamps are rejected. It performs no
+// signature or freshness checks — use ValidateRevocationList for that.
+func ParseRevocationList(data []byte) (*RevocationList, error) {
+	var rl RevocationList
+	if err := json.Unmarshal(data, &rl); err != nil {
+		return nil, fmt.Errorf("failed to unmarshal revocation list: %w", err)
+	}
+
+	// Initialize the map if it's nil (in case of empty JSON object)
+	if rl.Revoked == nil {
+		rl.Revoked = make(map[KeyID]time.Time)
+	}
+
+	if rl.LastUpdated.IsZero() {
+		return nil, fmt.Errorf("revocation list missing last_updated timestamp")
+	}
+
+	if rl.ExpiresAt.IsZero() {
+		return nil, fmt.Errorf("revocation list missing expires_at timestamp")
+	}
+
+	return &rl, nil
+}
+
+// ValidateRevocationList parses data as a RevocationList and verifies that
+// signature is a valid signature by one of publicRootKeys over it. Beyond the
+// cryptographic check it enforces freshness: the signature timestamp must not
+// be in the future (beyond maxClockSkew) nor older than
+// maxRevocationSignatureAge; LastUpdated must not be in the future; the list
+// must not be expired, nor expire further out than maxRevocationSignatureAge;
+// and the signature timestamp must sit within maxClockSkew of LastUpdated.
+// Returns the parsed list on success.
+func ValidateRevocationList(publicRootKeys []PublicKey, data []byte, signature Signature) (*RevocationList, error) {
+	revoList, err := ParseRevocationList(data)
+	if err != nil {
+		log.Debugf("failed to parse revocation list: %s", err)
+		return nil, err
+	}
+
+	now := time.Now().UTC()
+
+	// Validate signature timestamp
+	if signature.Timestamp.After(now.Add(maxClockSkew)) {
+		err := fmt.Errorf("revocation signature timestamp is in the future: %v", signature.Timestamp)
+		log.Debugf("revocation list signature error: %v", err)
+		return nil, err
+	}
+
+	if now.Sub(signature.Timestamp) > maxRevocationSignatureAge {
+		err := fmt.Errorf("revocation list signature is too old: %v (created %v)",
+			now.Sub(signature.Timestamp), signature.Timestamp)
+		log.Debugf("revocation list signature error: %v", err)
+		return nil, err
+	}
+
+	// Ensure LastUpdated is not in the future (with clock skew tolerance)
+	if revoList.LastUpdated.After(now.Add(maxClockSkew)) {
+		err := fmt.Errorf("revocation list LastUpdated is in the future: %v", revoList.LastUpdated)
+		log.Errorf("rejecting future-dated revocation list: %v", err)
+		return nil, err
+	}
+
+	// Check if the revocation list has expired
+	if now.After(revoList.ExpiresAt) {
+		err := fmt.Errorf("revocation list expired at %v (current time: %v)", revoList.ExpiresAt, now)
+		log.Errorf("rejecting expired revocation list: %v", err)
+		return nil, err
+	}
+
+	// Ensure ExpiresAt is not in the future by more than the expected expiration window
+	// (allows some clock skew but prevents maliciously long expiration times)
+	if revoList.ExpiresAt.After(now.Add(maxRevocationSignatureAge)) {
+		err := fmt.Errorf("revocation list ExpiresAt is too far in the future: %v", revoList.ExpiresAt)
+		log.Errorf("rejecting revocation list with invalid expiration: %v", err)
+		return nil, err
+	}
+
+	// Validate signature timestamp is close to LastUpdated
+	// (prevents signing old lists with new timestamps)
+	timeDiff := signature.Timestamp.Sub(revoList.LastUpdated).Abs()
+	if timeDiff > maxClockSkew {
+		err := fmt.Errorf("signature timestamp %v differs too much from list LastUpdated %v (diff: %v)",
+			signature.Timestamp, revoList.LastUpdated, timeDiff)
+		log.Errorf("timestamp mismatch in revocation list: %v", err)
+		return nil, err
+	}
+
+	// Reconstruct the signed message: revocation_list_data || timestamp
+	// (uint64 little-endian Unix seconds) — must mirror signRevocationList.
+	msg := make([]byte, 0, len(data)+8)
+	msg = append(msg, data...)
+	msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix()))
+
+	if !verifyAny(publicRootKeys, msg, signature.Signature) {
+		return nil, errors.New("revocation list verification failed")
+	}
+	return revoList, nil
+}
+
+// CreateRevocationList builds an empty revocation list valid for the given
+// expiration window, signs it with privateRootKey, and returns the JSON
+// encoding of the list and of its detached signature, in that order.
+//
+// Signing happens before the list is marshaled here; this is safe only
+// because signRevocationList marshals the same value and
+// RevocationList.MarshalJSON is deterministic (sorted map keys), so both
+// marshalings yield identical bytes.
+func CreateRevocationList(privateRootKey RootKey, expiration time.Duration) ([]byte, []byte, error) {
+	now := time.Now()
+	rl := RevocationList{
+		Revoked:     make(map[KeyID]time.Time),
+		LastUpdated: now.UTC(),
+		ExpiresAt:   now.Add(expiration).UTC(),
+	}
+
+	signature, err := signRevocationList(privateRootKey, rl)
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to sign revocation list: %w", err)
+	}
+
+	rlData, err := json.Marshal(&rl)
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to marshal revocation list: %w", err)
+	}
+
+	signData, err := json.Marshal(signature)
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to marshal signature: %w", err)
+	}
+
+	return rlData, signData, nil
+}
+
+// ExtendRevocationList adds kid to rl's revoked set, refreshes LastUpdated and
+// ExpiresAt (now + expiration), signs the updated list with privateRootKey,
+// and returns the JSON encoding of the list and of its detached signature.
+//
+// rl is taken by value, but its Revoked map header is shared with the caller:
+// when the caller's map is non-nil, the new revocation entry is visible to the
+// caller as well.
+func ExtendRevocationList(privateRootKey RootKey, rl RevocationList, kid KeyID, expiration time.Duration) ([]byte, []byte, error) {
+	now := time.Now().UTC()
+
+	// Guard against a zero-value list: assigning into a nil map panics.
+	if rl.Revoked == nil {
+		rl.Revoked = make(map[KeyID]time.Time, 1)
+	}
+	rl.Revoked[kid] = now
+	rl.LastUpdated = now
+	rl.ExpiresAt = now.Add(expiration)
+
+	signature, err := signRevocationList(privateRootKey, rl)
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to sign revocation list: %w", err)
+	}
+
+	rlData, err := json.Marshal(&rl)
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to marshal revocation list: %w", err)
+	}
+
+	signData, err := json.Marshal(signature)
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to marshal signature: %w", err)
+	}
+
+	return rlData, signData, nil
+}
+
+// signRevocationList signs rl with the root private key and returns a
+// Signature carrying the signing time. The signed message is
+// marshal(rl) || uint64 little-endian Unix seconds of the signing time;
+// ValidateRevocationList reconstructs exactly this message when verifying.
+func signRevocationList(privateRootKey RootKey, rl RevocationList) (*Signature, error) {
+	data, err := json.Marshal(rl)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal revocation list for signing: %w", err)
+	}
+
+	timestamp := time.Now().UTC()
+
+	// Binding the timestamp into the signed bytes prevents re-signing old list
+	// data under a fresh timestamp.
+	msg := make([]byte, 0, len(data)+8)
+	msg = append(msg, data...)
+	msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix()))
+
+	sig := ed25519.Sign(privateRootKey.Key, msg)
+
+	signature := &Signature{
+		Signature: sig,
+		Timestamp: timestamp,
+		KeyID:     privateRootKey.Metadata.ID,
+		Algorithm: "ed25519",
+		HashAlgo:  "sha512", // Ed25519 uses SHA-512 internally
+	}
+
+	return signature, nil
+}
diff --git a/client/internal/updatemanager/reposign/revocation_test.go b/client/internal/updatemanager/reposign/revocation_test.go
new file mode 100644
index 000000000..d6d748f3d
--- /dev/null
+++ b/client/internal/updatemanager/reposign/revocation_test.go
@@ -0,0 +1,860 @@
+package reposign
+
+import (
+ "crypto/ed25519"
+ "crypto/rand"
+ "encoding/json"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// Test RevocationList marshaling/unmarshaling
+
+func TestRevocationList_MarshalJSON(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ keyID := computeKeyID(pub)
+ revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
+ lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC)
+ expiresAt := time.Date(2024, 4, 15, 11, 0, 0, 0, time.UTC)
+
+ rl := &RevocationList{
+ Revoked: map[KeyID]time.Time{
+ keyID: revokedTime,
+ },
+ LastUpdated: lastUpdated,
+ ExpiresAt: expiresAt,
+ }
+
+ jsonData, err := json.Marshal(rl)
+ require.NoError(t, err)
+
+ // Verify it can be unmarshaled back
+ var decoded map[string]interface{}
+ err = json.Unmarshal(jsonData, &decoded)
+ require.NoError(t, err)
+
+ assert.Contains(t, decoded, "revoked")
+ assert.Contains(t, decoded, "last_updated")
+ assert.Contains(t, decoded, "expires_at")
+}
+
+func TestRevocationList_UnmarshalJSON(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ keyID := computeKeyID(pub)
+ revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
+ lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC)
+
+ jsonData := map[string]interface{}{
+ "revoked": map[string]string{
+ keyID.String(): revokedTime.Format(time.RFC3339),
+ },
+ "last_updated": lastUpdated.Format(time.RFC3339),
+ }
+
+ jsonBytes, err := json.Marshal(jsonData)
+ require.NoError(t, err)
+
+ var rl RevocationList
+ err = json.Unmarshal(jsonBytes, &rl)
+ require.NoError(t, err)
+
+ assert.Len(t, rl.Revoked, 1)
+ assert.Contains(t, rl.Revoked, keyID)
+ assert.Equal(t, lastUpdated.Unix(), rl.LastUpdated.Unix())
+}
+
+func TestRevocationList_MarshalUnmarshal_Roundtrip(t *testing.T) {
+ pub1, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+ pub2, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ keyID1 := computeKeyID(pub1)
+ keyID2 := computeKeyID(pub2)
+
+ original := &RevocationList{
+ Revoked: map[KeyID]time.Time{
+ keyID1: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC),
+ keyID2: time.Date(2024, 2, 20, 14, 45, 0, 0, time.UTC),
+ },
+ LastUpdated: time.Date(2024, 2, 20, 15, 0, 0, 0, time.UTC),
+ }
+
+ // Marshal
+ jsonData, err := original.MarshalJSON()
+ require.NoError(t, err)
+
+ // Unmarshal
+ var decoded RevocationList
+ err = decoded.UnmarshalJSON(jsonData)
+ require.NoError(t, err)
+
+ // Verify
+ assert.Len(t, decoded.Revoked, 2)
+ assert.Equal(t, original.Revoked[keyID1].Unix(), decoded.Revoked[keyID1].Unix())
+ assert.Equal(t, original.Revoked[keyID2].Unix(), decoded.Revoked[keyID2].Unix())
+ assert.Equal(t, original.LastUpdated.Unix(), decoded.LastUpdated.Unix())
+}
+
+func TestRevocationList_UnmarshalJSON_InvalidKeyID(t *testing.T) {
+ jsonData := []byte(`{
+ "revoked": {
+ "invalid_key_id": "2024-01-15T10:30:00Z"
+ },
+ "last_updated": "2024-01-15T11:00:00Z"
+ }`)
+
+ var rl RevocationList
+ err := json.Unmarshal(jsonData, &rl)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to parse KeyID")
+}
+
+func TestRevocationList_EmptyRevoked(t *testing.T) {
+ rl := &RevocationList{
+ Revoked: make(map[KeyID]time.Time),
+ LastUpdated: time.Now().UTC(),
+ }
+
+ jsonData, err := rl.MarshalJSON()
+ require.NoError(t, err)
+
+ var decoded RevocationList
+ err = decoded.UnmarshalJSON(jsonData)
+ require.NoError(t, err)
+
+ assert.Empty(t, decoded.Revoked)
+ assert.NotNil(t, decoded.Revoked)
+}
+
+// Test ParseRevocationList
+
+func TestParseRevocationList_Valid(t *testing.T) {
+ pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ keyID := computeKeyID(pub)
+ revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
+ lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC)
+
+ rl := RevocationList{
+ Revoked: map[KeyID]time.Time{
+ keyID: revokedTime,
+ },
+ LastUpdated: lastUpdated,
+ ExpiresAt: time.Date(2025, 2, 20, 14, 45, 0, 0, time.UTC),
+ }
+
+ jsonData, err := rl.MarshalJSON()
+ require.NoError(t, err)
+
+ parsed, err := ParseRevocationList(jsonData)
+ require.NoError(t, err)
+ assert.NotNil(t, parsed)
+ assert.Len(t, parsed.Revoked, 1)
+ assert.Equal(t, lastUpdated.Unix(), parsed.LastUpdated.Unix())
+}
+
+// Malformed JSON must surface an unmarshal error.
+func TestParseRevocationList_InvalidJSON(t *testing.T) {
+	invalidJSON := []byte("not valid json")
+
+	_, err := ParseRevocationList(invalidJSON)
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "failed to unmarshal")
+}
+
+// A list without last_updated is rejected even if otherwise well-formed.
+func TestParseRevocationList_MissingLastUpdated(t *testing.T) {
+	jsonData := []byte(`{
+		"revoked": {}
+	}`)
+
+	_, err := ParseRevocationList(jsonData)
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "missing last_updated")
+}
+
+// A completely empty object fails on the first required field.
+func TestParseRevocationList_EmptyObject(t *testing.T) {
+	jsonData := []byte(`{}`)
+
+	_, err := ParseRevocationList(jsonData)
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "missing last_updated")
+}
+
+func TestParseRevocationList_NilRevoked(t *testing.T) {
+ lastUpdated := time.Now().UTC()
+ expiresAt := lastUpdated.Add(90 * 24 * time.Hour)
+ jsonData := []byte(`{
+ "last_updated": "` + lastUpdated.Format(time.RFC3339) + `",
+ "expires_at": "` + expiresAt.Format(time.RFC3339) + `"
+ }`)
+
+ parsed, err := ParseRevocationList(jsonData)
+ require.NoError(t, err)
+ assert.NotNil(t, parsed.Revoked)
+ assert.Empty(t, parsed.Revoked)
+}
+
+func TestParseRevocationList_MissingExpiresAt(t *testing.T) {
+ lastUpdated := time.Now().UTC()
+ jsonData := []byte(`{
+ "revoked": {},
+ "last_updated": "` + lastUpdated.Format(time.RFC3339) + `"
+ }`)
+
+ _, err := ParseRevocationList(jsonData)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "missing expires_at")
+}
+
+// Test ValidateRevocationList
+
+func TestValidateRevocationList_Valid(t *testing.T) {
+ // Generate root key
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Create revocation list
+ rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ signature, err := ParseSignature(sigData)
+ require.NoError(t, err)
+
+ // Validate
+ rl, err := ValidateRevocationList(rootKeys, rlData, *signature)
+ require.NoError(t, err)
+ assert.NotNil(t, rl)
+ assert.Empty(t, rl.Revoked)
+}
+
+func TestValidateRevocationList_InvalidSignature(t *testing.T) {
+ // Generate root key
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Create revocation list
+ rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ // Create invalid signature
+ invalidSig := Signature{
+ Signature: make([]byte, 64),
+ Timestamp: time.Now().UTC(),
+ KeyID: computeKeyID(rootPub),
+ Algorithm: "ed25519",
+ HashAlgo: "sha512",
+ }
+
+ // Validate should fail
+ _, err = ValidateRevocationList(rootKeys, rlData, invalidSig)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "verification failed")
+}
+
+func TestValidateRevocationList_FutureTimestamp(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ signature, err := ParseSignature(sigData)
+ require.NoError(t, err)
+
+ // Modify timestamp to be in the future
+ signature.Timestamp = time.Now().UTC().Add(10 * time.Minute)
+
+ _, err = ValidateRevocationList(rootKeys, rlData, *signature)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "in the future")
+}
+
+func TestValidateRevocationList_TooOld(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ signature, err := ParseSignature(sigData)
+ require.NoError(t, err)
+
+ // Modify timestamp to be too old
+ signature.Timestamp = time.Now().UTC().Add(-20 * 365 * 24 * time.Hour)
+
+ _, err = ValidateRevocationList(rootKeys, rlData, *signature)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "too old")
+}
+
+func TestValidateRevocationList_InvalidJSON(t *testing.T) {
+ rootPub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ signature := Signature{
+ Signature: make([]byte, 64),
+ Timestamp: time.Now().UTC(),
+ KeyID: computeKeyID(rootPub),
+ Algorithm: "ed25519",
+ HashAlgo: "sha512",
+ }
+
+ _, err = ValidateRevocationList(rootKeys, []byte("invalid json"), signature)
+ assert.Error(t, err)
+}
+
+func TestValidateRevocationList_FutureLastUpdated(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Create revocation list with future LastUpdated
+ rl := RevocationList{
+ Revoked: make(map[KeyID]time.Time),
+ LastUpdated: time.Now().UTC().Add(10 * time.Minute),
+ ExpiresAt: time.Now().UTC().Add(365 * 24 * time.Hour),
+ }
+
+ rlData, err := json.Marshal(rl)
+ require.NoError(t, err)
+
+ // Sign it
+ sig, err := signRevocationList(rootKey, rl)
+ require.NoError(t, err)
+
+ _, err = ValidateRevocationList(rootKeys, rlData, *sig)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "LastUpdated is in the future")
+}
+
+func TestValidateRevocationList_TimestampMismatch(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Create revocation list with LastUpdated far in the past
+ rl := RevocationList{
+ Revoked: make(map[KeyID]time.Time),
+ LastUpdated: time.Now().UTC().Add(-1 * time.Hour),
+ ExpiresAt: time.Now().UTC().Add(365 * 24 * time.Hour),
+ }
+
+ rlData, err := json.Marshal(rl)
+ require.NoError(t, err)
+
+ // Sign it with current timestamp
+ sig, err := signRevocationList(rootKey, rl)
+ require.NoError(t, err)
+
+ // Modify signature timestamp to differ too much from LastUpdated
+ sig.Timestamp = time.Now().UTC()
+
+ _, err = ValidateRevocationList(rootKeys, rlData, *sig)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "differs too much")
+}
+
+func TestValidateRevocationList_Expired(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Create revocation list that expired in the past
+ now := time.Now().UTC()
+ rl := RevocationList{
+ Revoked: make(map[KeyID]time.Time),
+ LastUpdated: now.Add(-100 * 24 * time.Hour),
+ ExpiresAt: now.Add(-10 * 24 * time.Hour), // Expired 10 days ago
+ }
+
+ rlData, err := json.Marshal(rl)
+ require.NoError(t, err)
+
+ // Sign it
+ sig, err := signRevocationList(rootKey, rl)
+ require.NoError(t, err)
+ // Adjust signature timestamp to match LastUpdated
+ sig.Timestamp = rl.LastUpdated
+
+ _, err = ValidateRevocationList(rootKeys, rlData, *sig)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "expired")
+}
+
+func TestValidateRevocationList_ExpiresAtTooFarInFuture(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Create revocation list with ExpiresAt too far in the future (beyond maxRevocationSignatureAge)
+ now := time.Now().UTC()
+ rl := RevocationList{
+ Revoked: make(map[KeyID]time.Time),
+ LastUpdated: now,
+ ExpiresAt: now.Add(15 * 365 * 24 * time.Hour), // 15 years in the future
+ }
+
+ rlData, err := json.Marshal(rl)
+ require.NoError(t, err)
+
+ // Sign it
+ sig, err := signRevocationList(rootKey, rl)
+ require.NoError(t, err)
+
+ _, err = ValidateRevocationList(rootKeys, rlData, *sig)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "too far in the future")
+}
+
+// Test CreateRevocationList
+
+func TestCreateRevocationList_Valid(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
+ require.NoError(t, err)
+ assert.NotEmpty(t, rlData)
+ assert.NotEmpty(t, sigData)
+
+ // Verify it can be parsed
+ rl, err := ParseRevocationList(rlData)
+ require.NoError(t, err)
+ assert.Empty(t, rl.Revoked)
+ assert.False(t, rl.LastUpdated.IsZero())
+
+ // Verify signature can be parsed
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+ assert.NotEmpty(t, sig.Signature)
+}
+
+// Test ExtendRevocationList
+
+func TestExtendRevocationList_AddKey(t *testing.T) {
+ // Generate root key
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Create empty revocation list
+ rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ rl, err := ParseRevocationList(rlData)
+ require.NoError(t, err)
+ assert.Empty(t, rl.Revoked)
+
+ // Generate a key to revoke
+ revokedPub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+ revokedKeyID := computeKeyID(revokedPub)
+
+ // Extend the revocation list
+ newRLData, newSigData, err := ExtendRevocationList(rootKey, *rl, revokedKeyID, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ // Verify the new list
+ newRL, err := ParseRevocationList(newRLData)
+ require.NoError(t, err)
+ assert.Len(t, newRL.Revoked, 1)
+ assert.Contains(t, newRL.Revoked, revokedKeyID)
+
+ // Verify signature
+ sig, err := ParseSignature(newSigData)
+ require.NoError(t, err)
+ assert.NotEmpty(t, sig.Signature)
+}
+
+func TestExtendRevocationList_MultipleKeys(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Create empty revocation list
+ rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ rl, err := ParseRevocationList(rlData)
+ require.NoError(t, err)
+
+ // Add first key
+ key1Pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+ key1ID := computeKeyID(key1Pub)
+
+ rlData, _, err = ExtendRevocationList(rootKey, *rl, key1ID, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ rl, err = ParseRevocationList(rlData)
+ require.NoError(t, err)
+ assert.Len(t, rl.Revoked, 1)
+
+ // Add second key
+ key2Pub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+ key2ID := computeKeyID(key2Pub)
+
+ rlData, _, err = ExtendRevocationList(rootKey, *rl, key2ID, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ rl, err = ParseRevocationList(rlData)
+ require.NoError(t, err)
+ assert.Len(t, rl.Revoked, 2)
+ assert.Contains(t, rl.Revoked, key1ID)
+ assert.Contains(t, rl.Revoked, key2ID)
+}
+
+func TestExtendRevocationList_DuplicateKey(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Create empty revocation list
+ rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ rl, err := ParseRevocationList(rlData)
+ require.NoError(t, err)
+
+ // Add a key
+ keyPub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+ keyID := computeKeyID(keyPub)
+
+ rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ rl, err = ParseRevocationList(rlData)
+ require.NoError(t, err)
+ firstRevocationTime := rl.Revoked[keyID]
+
+ // Wait a bit
+ time.Sleep(10 * time.Millisecond)
+
+ // Add the same key again
+ rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ rl, err = ParseRevocationList(rlData)
+ require.NoError(t, err)
+ assert.Len(t, rl.Revoked, 1)
+
+ // The revocation time should be updated
+ secondRevocationTime := rl.Revoked[keyID]
+ assert.True(t, secondRevocationTime.After(firstRevocationTime) || secondRevocationTime.Equal(firstRevocationTime))
+}
+
+func TestExtendRevocationList_UpdatesLastUpdated(t *testing.T) {
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Create revocation list
+ rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ rl, err := ParseRevocationList(rlData)
+ require.NoError(t, err)
+ firstLastUpdated := rl.LastUpdated
+
+ // Wait a bit
+ time.Sleep(10 * time.Millisecond)
+
+ // Extend list
+ keyPub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+ keyID := computeKeyID(keyPub)
+
+ rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ rl, err = ParseRevocationList(rlData)
+ require.NoError(t, err)
+
+ // LastUpdated should be updated
+ assert.True(t, rl.LastUpdated.After(firstLastUpdated))
+}
+
+// Integration test
+
+func TestRevocationList_FullWorkflow(t *testing.T) {
+ // Create root key
+ rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ rootKey := RootKey{
+ PrivateKey{
+ Key: rootPriv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ rootKeys := []PublicKey{
+ {
+ Key: rootPub,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(rootPub),
+ CreatedAt: time.Now().UTC(),
+ },
+ },
+ }
+
+ // Step 1: Create empty revocation list
+ rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ // Step 2: Validate it
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+
+ rl, err := ValidateRevocationList(rootKeys, rlData, *sig)
+ require.NoError(t, err)
+ assert.Empty(t, rl.Revoked)
+
+ // Step 3: Revoke a key
+ revokedPub, _, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+ revokedKeyID := computeKeyID(revokedPub)
+
+ rlData, sigData, err = ExtendRevocationList(rootKey, *rl, revokedKeyID, defaultRevocationListExpiration)
+ require.NoError(t, err)
+
+ // Step 4: Validate the extended list
+ sig, err = ParseSignature(sigData)
+ require.NoError(t, err)
+
+ rl, err = ValidateRevocationList(rootKeys, rlData, *sig)
+ require.NoError(t, err)
+ assert.Len(t, rl.Revoked, 1)
+ assert.Contains(t, rl.Revoked, revokedKeyID)
+
+ // Step 5: Verify the revocation time is reasonable
+ revTime := rl.Revoked[revokedKeyID]
+ now := time.Now().UTC()
+ assert.True(t, revTime.Before(now) || revTime.Equal(now))
+ assert.True(t, now.Sub(revTime) < time.Minute)
+}
diff --git a/client/internal/updatemanager/reposign/root.go b/client/internal/updatemanager/reposign/root.go
new file mode 100644
index 000000000..2c3ca54a0
--- /dev/null
+++ b/client/internal/updatemanager/reposign/root.go
@@ -0,0 +1,120 @@
+package reposign
+
+import (
+ "crypto/ed25519"
+ "crypto/rand"
+ "encoding/binary"
+ "encoding/json"
+ "encoding/pem"
+ "fmt"
+ "time"
+)
+
+const (
+ tagRootPrivate = "ROOT PRIVATE KEY"
+ tagRootPublic = "ROOT PUBLIC KEY"
+)
+
+// RootKey is a root Key used to sign signing keys.
+// It embeds PrivateKey, so Key and Metadata are promoted fields on RootKey.
+type RootKey struct {
+	PrivateKey
+}
+
+// String renders the key ID and validity window for logging.
+// The private key material itself is never included in the output.
+func (k RootKey) String() string {
+	created := k.Metadata.CreatedAt.Format(time.RFC3339)
+	expires := k.Metadata.ExpiresAt.Format(time.RFC3339)
+	return fmt.Sprintf("RootKey[ID=%s, CreatedAt=%s, ExpiresAt=%s]", k.Metadata.ID, created, expires)
+}
+
+// ParseRootKey parses a root private key from PEM format.
+// The PEM block must carry the root-private-key type tag; any other
+// key type (e.g. an artifact key) is rejected by parsePrivateKey.
+func ParseRootKey(privKeyPEM []byte) (*RootKey, error) {
+	pk, err := parsePrivateKey(privKeyPEM, tagRootPrivate)
+	if err != nil {
+		return nil, fmt.Errorf("failed to parse root Key: %w", err)
+	}
+	return &RootKey{pk}, nil
+}
+
+// ParseRootPublicKey decodes a root public key from its PEM encoding,
+// requiring the root-public-key PEM type tag.
+func ParseRootPublicKey(pubKeyPEM []byte) (PublicKey, error) {
+	parsed, _, err := parsePublicKey(pubKeyPEM, tagRootPublic)
+	if err == nil {
+		return parsed, nil
+	}
+	return PublicKey{}, fmt.Errorf("failed to parse root public key: %w", err)
+}
+
+// GenerateRootKey generates a new root Key pair with Metadata.
+//
+// It returns the key plus PEM encodings of the private and public halves
+// (JSON-marshalled PrivateKey/PublicKey structs inside the PEM body).
+// expiration is the validity window measured from now; a zero duration
+// yields a key whose ExpiresAt equals its CreatedAt.
+func GenerateRootKey(expiration time.Duration) (*RootKey, []byte, []byte, error) {
+	// Take a single UTC timestamp so CreatedAt and ExpiresAt share one base
+	// instead of converting two local times separately.
+	now := time.Now().UTC()
+	pub, priv, err := ed25519.GenerateKey(rand.Reader)
+	if err != nil {
+		// Wrap for context, consistent with the other error paths below.
+		return nil, nil, nil, fmt.Errorf("failed to generate root key: %w", err)
+	}
+
+	metadata := KeyMetadata{
+		ID:        computeKeyID(pub),
+		CreatedAt: now,
+		ExpiresAt: now.Add(expiration),
+	}
+
+	rk := &RootKey{
+		PrivateKey{
+			Key:      priv,
+			Metadata: metadata,
+		},
+	}
+
+	// Marshal PrivateKey struct to JSON.
+	privJSON, err := json.Marshal(rk.PrivateKey)
+	if err != nil {
+		return nil, nil, nil, fmt.Errorf("failed to marshal private key: %w", err)
+	}
+
+	// Marshal PublicKey struct to JSON.
+	pubJSON, err := json.Marshal(PublicKey{
+		Key:      pub,
+		Metadata: metadata,
+	})
+	if err != nil {
+		return nil, nil, nil, fmt.Errorf("failed to marshal public key: %w", err)
+	}
+
+	// Encode to PEM with metadata embedded in the block bytes.
+	privPEM := pem.EncodeToMemory(&pem.Block{
+		Type:  tagRootPrivate,
+		Bytes: privJSON,
+	})
+
+	pubPEM := pem.EncodeToMemory(&pem.Block{
+		Type:  tagRootPublic,
+		Bytes: pubJSON,
+	})
+
+	return rk, privPEM, pubPEM, nil
+}
+
+// SignArtifactKey signs the given data (an artifact public-key encoding)
+// with the root key and returns a JSON-encoded Signature bundle.
+//
+// The current UTC time (second precision) is appended to the message before
+// signing, so the timestamp is cryptographically bound to the signature and
+// cannot be altered without invalidating it. Verifiers must reconstruct the
+// message as data || LittleEndian(uint64(timestamp.Unix())).
+func SignArtifactKey(rootKey RootKey, data []byte) ([]byte, error) {
+	timestamp := time.Now().UTC()
+
+	// This ensures the timestamp is cryptographically bound to the signature
+	msg := make([]byte, 0, len(data)+8)
+	msg = append(msg, data...)
+	msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix()))
+
+	sig := ed25519.Sign(rootKey.Key, msg)
+	// Create signature bundle with timestamp and Metadata.
+	// NOTE(review): HashAlgo "sha512" reflects Ed25519's internal use of
+	// SHA-512; no separate pre-hash of the data is performed here.
+	bundle := Signature{
+		Signature: sig,
+		Timestamp: timestamp,
+		KeyID:     rootKey.Metadata.ID,
+		Algorithm: "ed25519",
+		HashAlgo:  "sha512",
+	}
+
+	return json.Marshal(bundle)
+}
diff --git a/client/internal/updatemanager/reposign/root_test.go b/client/internal/updatemanager/reposign/root_test.go
new file mode 100644
index 000000000..e75e29729
--- /dev/null
+++ b/client/internal/updatemanager/reposign/root_test.go
@@ -0,0 +1,476 @@
+package reposign
+
+import (
+ "crypto/ed25519"
+ "crypto/rand"
+ "encoding/binary"
+ "encoding/json"
+ "encoding/pem"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// Test RootKey.String()
+
+func TestRootKey_String(t *testing.T) {
+ pub, priv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ createdAt := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
+ expiresAt := time.Date(2034, 1, 15, 10, 30, 0, 0, time.UTC)
+
+ rk := RootKey{
+ PrivateKey{
+ Key: priv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: createdAt,
+ ExpiresAt: expiresAt,
+ },
+ },
+ }
+
+ str := rk.String()
+ assert.Contains(t, str, "RootKey")
+ assert.Contains(t, str, computeKeyID(pub).String())
+ assert.Contains(t, str, "2024-01-15")
+ assert.Contains(t, str, "2034-01-15")
+}
+
+func TestRootKey_String_NoExpiration(t *testing.T) {
+ pub, priv, err := ed25519.GenerateKey(rand.Reader)
+ require.NoError(t, err)
+
+ createdAt := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
+
+ rk := RootKey{
+ PrivateKey{
+ Key: priv,
+ Metadata: KeyMetadata{
+ ID: computeKeyID(pub),
+ CreatedAt: createdAt,
+ ExpiresAt: time.Time{}, // No expiration
+ },
+ },
+ }
+
+ str := rk.String()
+ assert.Contains(t, str, "RootKey")
+ assert.Contains(t, str, "0001-01-01") // Zero time format
+}
+
+// Test GenerateRootKey
+
+func TestGenerateRootKey_Valid(t *testing.T) {
+ expiration := 10 * 365 * 24 * time.Hour // 10 years
+
+ rk, privPEM, pubPEM, err := GenerateRootKey(expiration)
+ require.NoError(t, err)
+ assert.NotNil(t, rk)
+ assert.NotEmpty(t, privPEM)
+ assert.NotEmpty(t, pubPEM)
+
+ // Verify the key has correct metadata
+ assert.False(t, rk.Metadata.CreatedAt.IsZero())
+ assert.False(t, rk.Metadata.ExpiresAt.IsZero())
+ assert.True(t, rk.Metadata.ExpiresAt.After(rk.Metadata.CreatedAt))
+
+ // Verify expiration is approximately correct
+ expectedExpiration := time.Now().Add(expiration)
+ timeDiff := rk.Metadata.ExpiresAt.Sub(expectedExpiration)
+ assert.True(t, timeDiff < time.Minute && timeDiff > -time.Minute)
+}
+
+func TestGenerateRootKey_ShortExpiration(t *testing.T) {
+ expiration := 24 * time.Hour // 1 day
+
+ rk, _, _, err := GenerateRootKey(expiration)
+ require.NoError(t, err)
+ assert.NotNil(t, rk)
+
+ // Verify expiration
+ expectedExpiration := time.Now().Add(expiration)
+ timeDiff := rk.Metadata.ExpiresAt.Sub(expectedExpiration)
+ assert.True(t, timeDiff < time.Minute && timeDiff > -time.Minute)
+}
+
+func TestGenerateRootKey_ZeroExpiration(t *testing.T) {
+ rk, _, _, err := GenerateRootKey(0)
+ require.NoError(t, err)
+ assert.NotNil(t, rk)
+
+ // With zero expiration, ExpiresAt should be equal to CreatedAt
+ assert.Equal(t, rk.Metadata.CreatedAt, rk.Metadata.ExpiresAt)
+}
+
+func TestGenerateRootKey_PEMFormat(t *testing.T) {
+ rk, privPEM, pubPEM, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ // Verify private key PEM
+ privBlock, _ := pem.Decode(privPEM)
+ require.NotNil(t, privBlock)
+ assert.Equal(t, tagRootPrivate, privBlock.Type)
+
+ var privKey PrivateKey
+ err = json.Unmarshal(privBlock.Bytes, &privKey)
+ require.NoError(t, err)
+ assert.Equal(t, rk.Key, privKey.Key)
+
+ // Verify public key PEM
+ pubBlock, _ := pem.Decode(pubPEM)
+ require.NotNil(t, pubBlock)
+ assert.Equal(t, tagRootPublic, pubBlock.Type)
+
+ var pubKey PublicKey
+ err = json.Unmarshal(pubBlock.Bytes, &pubKey)
+ require.NoError(t, err)
+ assert.Equal(t, rk.Metadata.ID, pubKey.Metadata.ID)
+}
+
+func TestGenerateRootKey_KeySize(t *testing.T) {
+ rk, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ // Ed25519 private key should be 64 bytes
+ assert.Equal(t, ed25519.PrivateKeySize, len(rk.Key))
+
+ // Ed25519 public key should be 32 bytes
+ pubKey := rk.Key.Public().(ed25519.PublicKey)
+ assert.Equal(t, ed25519.PublicKeySize, len(pubKey))
+}
+
+func TestGenerateRootKey_UniqueKeys(t *testing.T) {
+ rk1, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ rk2, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ // Different keys should have different IDs
+ assert.NotEqual(t, rk1.Metadata.ID, rk2.Metadata.ID)
+ assert.NotEqual(t, rk1.Key, rk2.Key)
+}
+
+// Test ParseRootKey
+
+func TestParseRootKey_Valid(t *testing.T) {
+ original, privPEM, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ parsed, err := ParseRootKey(privPEM)
+ require.NoError(t, err)
+ assert.NotNil(t, parsed)
+
+ // Verify the parsed key matches the original
+ assert.Equal(t, original.Key, parsed.Key)
+ assert.Equal(t, original.Metadata.ID, parsed.Metadata.ID)
+ assert.Equal(t, original.Metadata.CreatedAt.Unix(), parsed.Metadata.CreatedAt.Unix())
+ assert.Equal(t, original.Metadata.ExpiresAt.Unix(), parsed.Metadata.ExpiresAt.Unix())
+}
+
+func TestParseRootKey_InvalidPEM(t *testing.T) {
+ _, err := ParseRootKey([]byte("not a valid PEM"))
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to parse")
+}
+
+func TestParseRootKey_EmptyData(t *testing.T) {
+ _, err := ParseRootKey([]byte{})
+ assert.Error(t, err)
+}
+
+func TestParseRootKey_WrongType(t *testing.T) {
+	// Generate an artifact key instead of a root key.
+	rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+	require.NoError(t, err)
+
+	// Discard the key itself at the call site; only the PEM encoding matters
+	// here, which avoids the `_ = artifactKey` unused-variable workaround.
+	_, privPEM, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+	require.NoError(t, err)
+
+	// Parsing an artifact private key as a root key must fail on the PEM type.
+	_, err = ParseRootKey(privPEM)
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "PEM type")
+}
+
+func TestParseRootKey_CorruptedJSON(t *testing.T) {
+ // Create PEM with corrupted JSON
+ corruptedPEM := pem.EncodeToMemory(&pem.Block{
+ Type: tagRootPrivate,
+ Bytes: []byte("corrupted json data"),
+ })
+
+ _, err := ParseRootKey(corruptedPEM)
+ assert.Error(t, err)
+}
+
+func TestParseRootKey_InvalidKeySize(t *testing.T) {
+ // Create a key with invalid size
+ invalidKey := PrivateKey{
+ Key: []byte{0x01, 0x02, 0x03}, // Too short
+ Metadata: KeyMetadata{
+ ID: KeyID{},
+ CreatedAt: time.Now().UTC(),
+ },
+ }
+
+ privJSON, err := json.Marshal(invalidKey)
+ require.NoError(t, err)
+
+ invalidPEM := pem.EncodeToMemory(&pem.Block{
+ Type: tagRootPrivate,
+ Bytes: privJSON,
+ })
+
+ _, err = ParseRootKey(invalidPEM)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "incorrect Ed25519 private key size")
+}
+
+func TestParseRootKey_Roundtrip(t *testing.T) {
+ // Generate a key
+ original, privPEM, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ // Parse it
+ parsed, err := ParseRootKey(privPEM)
+ require.NoError(t, err)
+
+ // Generate PEM again from parsed key
+ privJSON2, err := json.Marshal(parsed.PrivateKey)
+ require.NoError(t, err)
+
+ privPEM2 := pem.EncodeToMemory(&pem.Block{
+ Type: tagRootPrivate,
+ Bytes: privJSON2,
+ })
+
+ // Parse again
+ parsed2, err := ParseRootKey(privPEM2)
+ require.NoError(t, err)
+
+ // Should still match original
+ assert.Equal(t, original.Key, parsed2.Key)
+ assert.Equal(t, original.Metadata.ID, parsed2.Metadata.ID)
+}
+
+// Test SignArtifactKey
+
+func TestSignArtifactKey_Valid(t *testing.T) {
+ rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ data := []byte("test data to sign")
+ sigData, err := SignArtifactKey(*rootKey, data)
+ require.NoError(t, err)
+ assert.NotEmpty(t, sigData)
+
+ // Parse and verify signature
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+ assert.NotEmpty(t, sig.Signature)
+ assert.Equal(t, rootKey.Metadata.ID, sig.KeyID)
+ assert.Equal(t, "ed25519", sig.Algorithm)
+ assert.Equal(t, "sha512", sig.HashAlgo)
+ assert.False(t, sig.Timestamp.IsZero())
+}
+
+func TestSignArtifactKey_EmptyData(t *testing.T) {
+ rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ sigData, err := SignArtifactKey(*rootKey, []byte{})
+ require.NoError(t, err)
+ assert.NotEmpty(t, sigData)
+
+ // Should still be able to parse
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+ assert.NotEmpty(t, sig.Signature)
+}
+
+func TestSignArtifactKey_Verify(t *testing.T) {
+ rootKey, _, pubPEM, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ // Parse public key
+ pubKey, _, err := parsePublicKey(pubPEM, tagRootPublic)
+ require.NoError(t, err)
+
+ // Sign some data
+ data := []byte("test data for verification")
+ sigData, err := SignArtifactKey(*rootKey, data)
+ require.NoError(t, err)
+
+ // Parse signature
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+
+ // Reconstruct message
+ msg := make([]byte, 0, len(data)+8)
+ msg = append(msg, data...)
+ msg = binary.LittleEndian.AppendUint64(msg, uint64(sig.Timestamp.Unix()))
+
+ // Verify signature
+ valid := ed25519.Verify(pubKey.Key, msg, sig.Signature)
+ assert.True(t, valid)
+}
+
+func TestSignArtifactKey_DifferentData(t *testing.T) {
+ rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ data1 := []byte("data1")
+ data2 := []byte("data2")
+
+ sig1, err := SignArtifactKey(*rootKey, data1)
+ require.NoError(t, err)
+
+ sig2, err := SignArtifactKey(*rootKey, data2)
+ require.NoError(t, err)
+
+ // Different data should produce different signatures
+ assert.NotEqual(t, sig1, sig2)
+}
+
+func TestSignArtifactKey_MultipleSignatures(t *testing.T) {
+ rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ data := []byte("test data")
+
+ // Sign twice with a small delay
+ sig1, err := SignArtifactKey(*rootKey, data)
+ require.NoError(t, err)
+
+ time.Sleep(10 * time.Millisecond)
+
+ sig2, err := SignArtifactKey(*rootKey, data)
+ require.NoError(t, err)
+
+ // Signatures should be different due to different timestamps
+ assert.NotEqual(t, sig1, sig2)
+
+ // Parse both signatures
+ parsed1, err := ParseSignature(sig1)
+ require.NoError(t, err)
+
+ parsed2, err := ParseSignature(sig2)
+ require.NoError(t, err)
+
+ // Timestamps should be different
+ assert.True(t, parsed2.Timestamp.After(parsed1.Timestamp))
+}
+
+func TestSignArtifactKey_LargeData(t *testing.T) {
+ rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ // Create 1MB of data
+ largeData := make([]byte, 1024*1024)
+ for i := range largeData {
+ largeData[i] = byte(i % 256)
+ }
+
+ sigData, err := SignArtifactKey(*rootKey, largeData)
+ require.NoError(t, err)
+ assert.NotEmpty(t, sigData)
+
+ // Verify signature can be parsed
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+ assert.NotEmpty(t, sig.Signature)
+}
+
+func TestSignArtifactKey_TimestampInSignature(t *testing.T) {
+ rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+ require.NoError(t, err)
+
+ beforeSign := time.Now().UTC()
+ data := []byte("test data")
+ sigData, err := SignArtifactKey(*rootKey, data)
+ require.NoError(t, err)
+ afterSign := time.Now().UTC()
+
+ sig, err := ParseSignature(sigData)
+ require.NoError(t, err)
+
+ // Timestamp should be between before and after
+ assert.True(t, sig.Timestamp.After(beforeSign.Add(-time.Second)))
+ assert.True(t, sig.Timestamp.Before(afterSign.Add(time.Second)))
+}
+
+// Integration test
+
+func TestRootKey_FullWorkflow(t *testing.T) {
+ // Step 1: Generate root key
+ rootKey, privPEM, pubPEM, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
+ require.NoError(t, err)
+ assert.NotNil(t, rootKey)
+ assert.NotEmpty(t, privPEM)
+ assert.NotEmpty(t, pubPEM)
+
+ // Step 2: Parse the private key back
+ parsedRootKey, err := ParseRootKey(privPEM)
+ require.NoError(t, err)
+ assert.Equal(t, rootKey.Key, parsedRootKey.Key)
+ assert.Equal(t, rootKey.Metadata.ID, parsedRootKey.Metadata.ID)
+
+ // Step 3: Generate an artifact key using root key
+ artifactKey, _, artifactPubPEM, artifactSig, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ require.NoError(t, err)
+ assert.NotNil(t, artifactKey)
+
+ // Step 4: Verify the artifact key signature
+ pubKey, _, err := parsePublicKey(pubPEM, tagRootPublic)
+ require.NoError(t, err)
+
+ sig, err := ParseSignature(artifactSig)
+ require.NoError(t, err)
+
+ artifactPubKey, _, err := parsePublicKey(artifactPubPEM, tagArtifactPublic)
+ require.NoError(t, err)
+
+ // Reconstruct message - SignArtifactKey signs the PEM, not the JSON
+ msg := make([]byte, 0, len(artifactPubPEM)+8)
+ msg = append(msg, artifactPubPEM...)
+ msg = binary.LittleEndian.AppendUint64(msg, uint64(sig.Timestamp.Unix()))
+
+ // Verify with root public key
+ valid := ed25519.Verify(pubKey.Key, msg, sig.Signature)
+ assert.True(t, valid, "Artifact key signature should be valid")
+
+ // Step 5: Use artifact key to sign data
+ testData := []byte("This is test artifact data")
+ dataSig, err := SignData(*artifactKey, testData)
+ require.NoError(t, err)
+ assert.NotEmpty(t, dataSig)
+
+ // Step 6: Verify the artifact data signature
+ dataSigParsed, err := ParseSignature(dataSig)
+ require.NoError(t, err)
+
+ err = ValidateArtifact([]PublicKey{artifactPubKey}, testData, *dataSigParsed)
+ assert.NoError(t, err, "Artifact data signature should be valid")
+}
+
+func TestRootKey_ExpiredKeyWorkflow(t *testing.T) {
+ // Generate a root key that expires very soon
+ rootKey, _, _, err := GenerateRootKey(1 * time.Millisecond)
+ require.NoError(t, err)
+
+ // Wait for expiration
+ time.Sleep(10 * time.Millisecond)
+
+ // Try to generate artifact key with expired root key
+ _, _, _, _, err = GenerateArtifactKey(rootKey, 30*24*time.Hour)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "expired")
+}
diff --git a/client/internal/updatemanager/reposign/signature.go b/client/internal/updatemanager/reposign/signature.go
new file mode 100644
index 000000000..c7f06e94e
--- /dev/null
+++ b/client/internal/updatemanager/reposign/signature.go
@@ -0,0 +1,24 @@
+package reposign
+
+import (
+ "encoding/json"
+ "time"
+)
+
+// Signature contains a signature with associated Metadata.
+// It is serialized to JSON and stored alongside the signed artifact.
+type Signature struct {
+	Signature []byte    `json:"signature"` // raw signature bytes
+	Timestamp time.Time `json:"timestamp"` // UTC time the signature was produced
+	KeyID     KeyID     `json:"key_id"`    // ID of the key that produced the signature
+	Algorithm string    `json:"algorithm"` // "ed25519"
+	HashAlgo  string    `json:"hash_algo"` // "blake2s" or "sha512"
+}
+
+func ParseSignature(data []byte) (*Signature, error) {
+ var signature Signature
+ if err := json.Unmarshal(data, &signature); err != nil {
+ return nil, err
+ }
+
+ return &signature, nil
+}
diff --git a/client/internal/updatemanager/reposign/signature_test.go b/client/internal/updatemanager/reposign/signature_test.go
new file mode 100644
index 000000000..1960c5518
--- /dev/null
+++ b/client/internal/updatemanager/reposign/signature_test.go
@@ -0,0 +1,277 @@
+package reposign
+
+import (
+ "encoding/json"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestParseSignature_Valid(t *testing.T) {
+ timestamp := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
+ keyID, err := ParseKeyID("0123456789abcdef")
+ require.NoError(t, err)
+
+ signatureData := []byte{0x01, 0x02, 0x03, 0x04}
+
+ jsonData, err := json.Marshal(Signature{
+ Signature: signatureData,
+ Timestamp: timestamp,
+ KeyID: keyID,
+ Algorithm: "ed25519",
+ HashAlgo: "blake2s",
+ })
+ require.NoError(t, err)
+
+ sig, err := ParseSignature(jsonData)
+ require.NoError(t, err)
+ assert.NotNil(t, sig)
+ assert.Equal(t, signatureData, sig.Signature)
+ assert.Equal(t, timestamp.Unix(), sig.Timestamp.Unix())
+ assert.Equal(t, keyID, sig.KeyID)
+ assert.Equal(t, "ed25519", sig.Algorithm)
+ assert.Equal(t, "blake2s", sig.HashAlgo)
+}
+
+func TestParseSignature_InvalidJSON(t *testing.T) {
+ invalidJSON := []byte(`{invalid json}`)
+
+ sig, err := ParseSignature(invalidJSON)
+ assert.Error(t, err)
+ assert.Nil(t, sig)
+}
+
+func TestParseSignature_EmptyData(t *testing.T) {
+ emptyJSON := []byte(`{}`)
+
+ sig, err := ParseSignature(emptyJSON)
+ require.NoError(t, err)
+ assert.NotNil(t, sig)
+ assert.Empty(t, sig.Signature)
+ assert.True(t, sig.Timestamp.IsZero())
+ assert.Equal(t, KeyID{}, sig.KeyID)
+ assert.Empty(t, sig.Algorithm)
+ assert.Empty(t, sig.HashAlgo)
+}
+
+func TestParseSignature_MissingFields(t *testing.T) {
+ // JSON with only some fields
+ partialJSON := []byte(`{
+ "signature": "AQIDBA==",
+ "algorithm": "ed25519"
+ }`)
+
+ sig, err := ParseSignature(partialJSON)
+ require.NoError(t, err)
+ assert.NotNil(t, sig)
+ assert.NotEmpty(t, sig.Signature)
+ assert.Equal(t, "ed25519", sig.Algorithm)
+ assert.True(t, sig.Timestamp.IsZero())
+}
+
+func TestSignature_MarshalUnmarshal_Roundtrip(t *testing.T) {
+ timestamp := time.Date(2024, 6, 20, 14, 45, 30, 0, time.UTC)
+ keyID, err := ParseKeyID("fedcba9876543210")
+ require.NoError(t, err)
+
+ original := Signature{
+ Signature: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe},
+ Timestamp: timestamp,
+ KeyID: keyID,
+ Algorithm: "ed25519",
+ HashAlgo: "sha512",
+ }
+
+ // Marshal
+ jsonData, err := json.Marshal(original)
+ require.NoError(t, err)
+
+ // Unmarshal
+ parsed, err := ParseSignature(jsonData)
+ require.NoError(t, err)
+
+ // Verify
+ assert.Equal(t, original.Signature, parsed.Signature)
+ assert.Equal(t, original.Timestamp.Unix(), parsed.Timestamp.Unix())
+ assert.Equal(t, original.KeyID, parsed.KeyID)
+ assert.Equal(t, original.Algorithm, parsed.Algorithm)
+ assert.Equal(t, original.HashAlgo, parsed.HashAlgo)
+}
+
+func TestSignature_NilSignatureBytes(t *testing.T) {
+ timestamp := time.Now().UTC()
+ keyID, err := ParseKeyID("0011223344556677")
+ require.NoError(t, err)
+
+ sig := Signature{
+ Signature: nil,
+ Timestamp: timestamp,
+ KeyID: keyID,
+ Algorithm: "ed25519",
+ HashAlgo: "blake2s",
+ }
+
+ jsonData, err := json.Marshal(sig)
+ require.NoError(t, err)
+
+ parsed, err := ParseSignature(jsonData)
+ require.NoError(t, err)
+ assert.Nil(t, parsed.Signature)
+}
+
+func TestSignature_LargeSignature(t *testing.T) {
+ timestamp := time.Now().UTC()
+ keyID, err := ParseKeyID("aabbccddeeff0011")
+ require.NoError(t, err)
+
+ // Create a large signature (64 bytes for ed25519)
+ largeSignature := make([]byte, 64)
+ for i := range largeSignature {
+ largeSignature[i] = byte(i)
+ }
+
+ sig := Signature{
+ Signature: largeSignature,
+ Timestamp: timestamp,
+ KeyID: keyID,
+ Algorithm: "ed25519",
+ HashAlgo: "blake2s",
+ }
+
+ jsonData, err := json.Marshal(sig)
+ require.NoError(t, err)
+
+ parsed, err := ParseSignature(jsonData)
+ require.NoError(t, err)
+ assert.Equal(t, largeSignature, parsed.Signature)
+}
+
+func TestSignature_WithDifferentHashAlgorithms(t *testing.T) {
+ tests := []struct {
+ name string
+ hashAlgo string
+ }{
+ {"blake2s", "blake2s"},
+ {"sha512", "sha512"},
+ {"sha256", "sha256"},
+ {"empty", ""},
+ }
+
+ keyID, err := ParseKeyID("1122334455667788")
+ require.NoError(t, err)
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ sig := Signature{
+ Signature: []byte{0x01, 0x02},
+ Timestamp: time.Now().UTC(),
+ KeyID: keyID,
+ Algorithm: "ed25519",
+ HashAlgo: tt.hashAlgo,
+ }
+
+ jsonData, err := json.Marshal(sig)
+ require.NoError(t, err)
+
+ parsed, err := ParseSignature(jsonData)
+ require.NoError(t, err)
+ assert.Equal(t, tt.hashAlgo, parsed.HashAlgo)
+ })
+ }
+}
+
+func TestSignature_TimestampPrecision(t *testing.T) {
+ // Test that timestamp preserves precision through JSON marshaling
+ timestamp := time.Date(2024, 3, 15, 10, 30, 45, 123456789, time.UTC)
+ keyID, err := ParseKeyID("8877665544332211")
+ require.NoError(t, err)
+
+ sig := Signature{
+ Signature: []byte{0xaa, 0xbb},
+ Timestamp: timestamp,
+ KeyID: keyID,
+ Algorithm: "ed25519",
+ HashAlgo: "blake2s",
+ }
+
+ jsonData, err := json.Marshal(sig)
+ require.NoError(t, err)
+
+ parsed, err := ParseSignature(jsonData)
+ require.NoError(t, err)
+
+ // JSON timestamps typically have second or millisecond precision
+ // so we check that at least seconds match
+ assert.Equal(t, timestamp.Unix(), parsed.Timestamp.Unix())
+}
+
+func TestParseSignature_MalformedKeyID(t *testing.T) {
+ // Test with a malformed KeyID field
+ malformedJSON := []byte(`{
+ "signature": "AQID",
+ "timestamp": "2024-01-15T10:30:00Z",
+ "key_id": "invalid_keyid_format",
+ "algorithm": "ed25519",
+ "hash_algo": "blake2s"
+ }`)
+
+ // This should fail since "invalid_keyid_format" is not a valid KeyID
+ sig, err := ParseSignature(malformedJSON)
+ assert.Error(t, err)
+ assert.Nil(t, sig)
+}
+
+func TestParseSignature_InvalidTimestamp(t *testing.T) {
+ // Test with an invalid timestamp format
+ invalidTimestampJSON := []byte(`{
+ "signature": "AQID",
+ "timestamp": "not-a-timestamp",
+ "key_id": "0123456789abcdef",
+ "algorithm": "ed25519",
+ "hash_algo": "blake2s"
+ }`)
+
+ sig, err := ParseSignature(invalidTimestampJSON)
+ assert.Error(t, err)
+ assert.Nil(t, sig)
+}
+
+func TestSignature_ZeroKeyID(t *testing.T) {
+ // Test with a zero KeyID
+ sig := Signature{
+ Signature: []byte{0x01, 0x02, 0x03},
+ Timestamp: time.Now().UTC(),
+ KeyID: KeyID{},
+ Algorithm: "ed25519",
+ HashAlgo: "blake2s",
+ }
+
+ jsonData, err := json.Marshal(sig)
+ require.NoError(t, err)
+
+ parsed, err := ParseSignature(jsonData)
+ require.NoError(t, err)
+ assert.Equal(t, KeyID{}, parsed.KeyID)
+}
+
+func TestParseSignature_ExtraFields(t *testing.T) {
+ // JSON with extra fields that should be ignored
+ jsonWithExtra := []byte(`{
+ "signature": "AQIDBA==",
+ "timestamp": "2024-01-15T10:30:00Z",
+ "key_id": "0123456789abcdef",
+ "algorithm": "ed25519",
+ "hash_algo": "blake2s",
+ "extra_field": "should be ignored",
+ "another_extra": 12345
+ }`)
+
+ sig, err := ParseSignature(jsonWithExtra)
+ require.NoError(t, err)
+ assert.NotNil(t, sig)
+ assert.NotEmpty(t, sig.Signature)
+ assert.Equal(t, "ed25519", sig.Algorithm)
+ assert.Equal(t, "blake2s", sig.HashAlgo)
+}
diff --git a/client/internal/updatemanager/reposign/verify.go b/client/internal/updatemanager/reposign/verify.go
new file mode 100644
index 000000000..0af2a8c9e
--- /dev/null
+++ b/client/internal/updatemanager/reposign/verify.go
@@ -0,0 +1,187 @@
+package reposign
+
+import (
+ "context"
+ "fmt"
+ "net/url"
+ "os"
+ "path/filepath"
+ "strings"
+
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/client/internal/updatemanager/downloader"
+)
+
+const (
+	// Well-known object names under the "keys" path of the signature repository.
+	artifactPubKeysFileName    = "artifact-key-pub.pem"
+	artifactPubKeysSigFileName = "artifact-key-pub.pem.sig"
+	revocationFileName         = "revocation-list.json"
+	revocationSignFileName     = "revocation-list.json.sig"
+
+	// Download size caps (bytes) passed to downloader.DownloadToMemory to
+	// bound memory use on oversized or hostile responses.
+	keySizeLimit    = 5 * 1024 * 1024 //5MB
+	signatureLimit  = 1024
+	revocationLimit = 10 * 1024 * 1024
+)
+
+// ArtifactVerify validates downloaded release artifacts against signatures
+// published in a remote keys repository, anchored in a set of trusted root keys.
+type ArtifactVerify struct {
+	// rootKeys are the trusted root public keys used to validate the
+	// revocation list and the artifact public-key bundle.
+	rootKeys []PublicKey
+	// keysBaseURL is the base URL of the signature repository; paths like
+	// "keys/..." and "tag/v<version>/..." are joined onto it.
+	keysBaseURL *url.URL
+
+	// revocationList is populated by Verify before artifact keys are loaded;
+	// revoked key IDs are excluded during key validation.
+	revocationList *RevocationList
+}
+
+// NewArtifactVerify creates an ArtifactVerify rooted in the public keys
+// embedded into the binary (see loadEmbeddedPublicKeys). It returns an error
+// if the embedded keys cannot be read/parsed or the base URL is invalid.
+func NewArtifactVerify(keysBaseURL string) (*ArtifactVerify, error) {
+	allKeys, err := loadEmbeddedPublicKeys()
+	if err != nil {
+		return nil, err
+	}
+
+	return newArtifactVerify(keysBaseURL, allKeys)
+}
+
+// newArtifactVerify builds an ArtifactVerify from the given base URL and set
+// of trusted root keys. Split from NewArtifactVerify so tests can inject
+// freshly generated root keys instead of the embedded ones.
+func newArtifactVerify(keysBaseURL string, allKeys []PublicKey) (*ArtifactVerify, error) {
+	ku, err := url.Parse(keysBaseURL)
+	if err != nil {
+		// %w (not %v) so callers can unwrap the underlying *url.Error.
+		return nil, fmt.Errorf("invalid keys base URL %q: %w", keysBaseURL, err)
+	}
+
+	a := &ArtifactVerify{
+		rootKeys:    allKeys,
+		keysBaseURL: ku,
+	}
+	return a, nil
+}
+
+// Verify downloads the revocation list, the artifact public-key bundle and
+// the detached signature for the given release version, then validates the
+// local artifact file against them. version may be given with or without a
+// leading "v". A nil return means the artifact content is authentic.
+func (a *ArtifactVerify) Verify(ctx context.Context, version string, artifactFile string) error {
+	version = strings.TrimPrefix(version, "v")
+
+	revocationList, err := a.loadRevocationList(ctx)
+	if err != nil {
+		return fmt.Errorf("failed to load revocation list: %w", err)
+	}
+	// Stored so loadArtifactKeys can exclude revoked signing keys.
+	// NOTE(review): this field write makes Verify non-reentrant — confirm a
+	// single ArtifactVerify is never shared across goroutines.
+	a.revocationList = revocationList
+
+	artifactPubKeys, err := a.loadArtifactKeys(ctx)
+	if err != nil {
+		return fmt.Errorf("failed to load artifact keys: %w", err)
+	}
+
+	signature, err := a.loadArtifactSignature(ctx, version, artifactFile)
+	if err != nil {
+		return fmt.Errorf("failed to download signature file for: %s, %w", filepath.Base(artifactFile), err)
+	}
+
+	artifactData, err := os.ReadFile(artifactFile)
+	if err != nil {
+		log.Errorf("failed to read artifact file: %v", err)
+		return fmt.Errorf("failed to read artifact file: %w", err)
+	}
+
+	// Final check: the artifact bytes must verify against one of the
+	// still-valid artifact public keys.
+	if err := ValidateArtifact(artifactPubKeys, artifactData, *signature); err != nil {
+		return fmt.Errorf("failed to validate artifact: %w", err)
+	}
+
+	return nil
+}
+
+// loadRevocationList downloads the signed revocation list and its detached
+// signature from the keys repository, verifies the signature against the
+// trusted root keys, and returns the parsed list.
+func (a *ArtifactVerify) loadRevocationList(ctx context.Context) (*RevocationList, error) {
+	downloadURL := a.keysBaseURL.JoinPath("keys", revocationFileName).String()
+	data, err := downloader.DownloadToMemory(ctx, downloadURL, revocationLimit)
+	if err != nil {
+		log.Debugf("failed to download revocation list '%s': %s", downloadURL, err)
+		return nil, err
+	}
+
+	downloadURL = a.keysBaseURL.JoinPath("keys", revocationSignFileName).String()
+	sigData, err := downloader.DownloadToMemory(ctx, downloadURL, signatureLimit)
+	if err != nil {
+		// Fixed copy-pasted message: this download is the list's signature.
+		log.Debugf("failed to download revocation list signature '%s': %s", downloadURL, err)
+		return nil, err
+	}
+
+	signature, err := ParseSignature(sigData)
+	if err != nil {
+		log.Debugf("failed to parse revocation list signature: %s", err)
+		return nil, err
+	}
+
+	return ValidateRevocationList(a.rootKeys, data, *signature)
+}
+
+// loadArtifactKeys downloads the artifact public-key bundle and its detached
+// signature, verifies the bundle against the trusted root keys, and filters
+// it through the previously loaded revocation list (a.revocationList must be
+// set by the caller — Verify does this before invoking us).
+func (a *ArtifactVerify) loadArtifactKeys(ctx context.Context) ([]PublicKey, error) {
+	downloadURL := a.keysBaseURL.JoinPath("keys", artifactPubKeysFileName).String()
+	log.Debugf("starting downloading artifact keys from: %s", downloadURL)
+	data, err := downloader.DownloadToMemory(ctx, downloadURL, keySizeLimit)
+	if err != nil {
+		log.Debugf("failed to download artifact keys: %s", err)
+		return nil, err
+	}
+
+	downloadURL = a.keysBaseURL.JoinPath("keys", artifactPubKeysSigFileName).String()
+	log.Debugf("start downloading signature of artifact pub key from: %s", downloadURL)
+	sigData, err := downloader.DownloadToMemory(ctx, downloadURL, signatureLimit)
+	if err != nil {
+		log.Debugf("failed to download signature of public keys: %s", err)
+		return nil, err
+	}
+
+	signature, err := ParseSignature(sigData)
+	if err != nil {
+		log.Debugf("failed to parse signature of public keys: %s", err)
+		return nil, err
+	}
+
+	return ValidateArtifactKeys(a.rootKeys, data, *signature, a.revocationList)
+}
+
+// loadArtifactSignature downloads and parses the detached signature published
+// for the named artifact under tag/v<version>/ in the signature repository.
+func (a *ArtifactVerify) loadArtifactSignature(ctx context.Context, version string, artifactFile string) (*Signature, error) {
+	// Only the base name matters; the artifact may live anywhere locally.
+	name := filepath.Base(artifactFile)
+	sigURL := a.keysBaseURL.JoinPath("tag", "v"+version, name+".sig").String()
+
+	raw, err := downloader.DownloadToMemory(ctx, sigURL, signatureLimit)
+	if err != nil {
+		log.Debugf("failed to download artifact signature: %s", err)
+		return nil, err
+	}
+
+	sig, err := ParseSignature(raw)
+	if err != nil {
+		log.Debugf("failed to parse artifact signature: %s", err)
+		return nil, err
+	}
+	return sig, nil
+}
+
+// loadEmbeddedPublicKeys reads every file (non-recursively, directories are
+// skipped) in the embedded certs directory and parses each as a root public
+// key bundle. It fails if any file is unreadable/unparseable or if the result
+// contains no keys at all, so a broken embed cannot silently disable
+// signature verification.
+func loadEmbeddedPublicKeys() ([]PublicKey, error) {
+	files, err := embeddedCerts.ReadDir(embeddedCertsDir)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read embedded certs: %w", err)
+	}
+
+	var allKeys []PublicKey
+	for _, file := range files {
+		if file.IsDir() {
+			continue
+		}
+
+		data, err := embeddedCerts.ReadFile(embeddedCertsDir + "/" + file.Name())
+		if err != nil {
+			return nil, fmt.Errorf("failed to read cert file %s: %w", file.Name(), err)
+		}
+
+		keys, err := parsePublicKeyBundle(data, tagRootPublic)
+		if err != nil {
+			return nil, fmt.Errorf("failed to parse cert %s: %w", file.Name(), err)
+		}
+
+		allKeys = append(allKeys, keys...)
+	}
+
+	if len(allKeys) == 0 {
+		return nil, fmt.Errorf("no valid public keys found in embedded certs")
+	}
+
+	return allKeys, nil
+}
diff --git a/client/internal/updatemanager/reposign/verify_test.go b/client/internal/updatemanager/reposign/verify_test.go
new file mode 100644
index 000000000..c29393bad
--- /dev/null
+++ b/client/internal/updatemanager/reposign/verify_test.go
@@ -0,0 +1,528 @@
+package reposign
+
+import (
+ "context"
+ "crypto/ed25519"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// Test ArtifactVerify construction
+
+// TestArtifactVerify_Construction checks that newArtifactVerify stores the
+// injected root key and base URL verbatim. Construction performs no network
+// I/O, so the localhost URL is never dialed.
+func TestArtifactVerify_Construction(t *testing.T) {
+	// Generate test root key
+	rootKey, _, rootPubPEM, err := GenerateRootKey(365 * 24 * time.Hour)
+	require.NoError(t, err)
+
+	rootPubKey, _, err := parsePublicKey(rootPubPEM, tagRootPublic)
+	require.NoError(t, err)
+
+	keysBaseURL := "http://localhost:8080/artifact-signatures"
+
+	av, err := newArtifactVerify(keysBaseURL, []PublicKey{rootPubKey})
+	require.NoError(t, err)
+
+	assert.NotNil(t, av)
+	assert.NotEmpty(t, av.rootKeys)
+	assert.Equal(t, keysBaseURL, av.keysBaseURL.String())
+
+	// Verify root key structure
+	assert.NotEmpty(t, av.rootKeys[0].Key)
+	assert.Equal(t, rootKey.Metadata.ID, av.rootKeys[0].Metadata.ID)
+	assert.False(t, av.rootKeys[0].Metadata.CreatedAt.IsZero())
+}
+
+// TestArtifactVerify_MultipleRootKeys checks that every supplied root key is
+// retained and that independently generated keys receive distinct IDs.
+func TestArtifactVerify_MultipleRootKeys(t *testing.T) {
+	// Generate multiple test root keys
+	rootKey1, _, rootPubPEM1, err := GenerateRootKey(365 * 24 * time.Hour)
+	require.NoError(t, err)
+	rootPubKey1, _, err := parsePublicKey(rootPubPEM1, tagRootPublic)
+	require.NoError(t, err)
+
+	rootKey2, _, rootPubPEM2, err := GenerateRootKey(365 * 24 * time.Hour)
+	require.NoError(t, err)
+	rootPubKey2, _, err := parsePublicKey(rootPubPEM2, tagRootPublic)
+	require.NoError(t, err)
+
+	keysBaseURL := "http://localhost:8080/artifact-signatures"
+
+	av, err := newArtifactVerify(keysBaseURL, []PublicKey{rootPubKey1, rootPubKey2})
+	// require (not assert): av is dereferenced right below and would panic on
+	// a nil value if construction failed.
+	require.NoError(t, err)
+	assert.Len(t, av.rootKeys, 2)
+	assert.NotEqual(t, rootKey1.Metadata.ID, rootKey2.Metadata.ID)
+}
+
+// Test Verify workflow with mock HTTP server
+
+// TestArtifactVerify_FullWorkflow exercises the happy path end to end: a
+// root key signs a revocation list and an artifact-key bundle, an artifact
+// key signs the artifact, all four signed objects are served by a local
+// httptest server, and Verify must succeed.
+func TestArtifactVerify_FullWorkflow(t *testing.T) {
+	// Create temporary test directory
+	tempDir := t.TempDir()
+
+	// Step 1: Generate root key
+	rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
+	require.NoError(t, err)
+
+	// Step 2: Generate artifact key
+	artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+	require.NoError(t, err)
+
+	artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM)
+	require.NoError(t, err)
+
+	// Step 3: Create revocation list
+	revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
+	require.NoError(t, err)
+
+	// Step 4: Bundle artifact keys
+	artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey})
+	require.NoError(t, err)
+
+	// Step 5: Create test artifact
+	artifactPath := filepath.Join(tempDir, "test-artifact.bin")
+	artifactData := []byte("This is test artifact data for verification")
+	err = os.WriteFile(artifactPath, artifactData, 0644)
+	require.NoError(t, err)
+
+	// Step 6: Sign artifact
+	artifactSigData, err := SignData(*artifactKey, artifactData)
+	require.NoError(t, err)
+
+	// Step 7: Setup mock HTTP server serving the signed repository layout
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		switch r.URL.Path {
+		case "/artifact-signatures/keys/" + revocationFileName:
+			_, _ = w.Write(revocationData)
+		case "/artifact-signatures/keys/" + revocationSignFileName:
+			_, _ = w.Write(revocationSig)
+		case "/artifact-signatures/keys/" + artifactPubKeysFileName:
+			_, _ = w.Write(artifactKeysBundle)
+		case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
+			_, _ = w.Write(artifactKeysSig)
+		case "/artifacts/v1.0.0/test-artifact.bin":
+			_, _ = w.Write(artifactData)
+		case "/artifact-signatures/tag/v1.0.0/test-artifact.bin.sig":
+			_, _ = w.Write(artifactSigData)
+		default:
+			http.NotFound(w, r)
+		}
+	}))
+	defer server.Close()
+
+	// Step 8: Create ArtifactVerify with test root key
+	rootPubKey := PublicKey{
+		Key:      rootKey.Key.Public().(ed25519.PublicKey),
+		Metadata: rootKey.Metadata,
+	}
+
+	av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
+	require.NoError(t, err)
+
+	// Step 9: Verify artifact
+	ctx := context.Background()
+	err = av.Verify(ctx, "1.0.0", artifactPath)
+	assert.NoError(t, err)
+}
+
+// TestArtifactVerify_InvalidRevocationList checks that Verify aborts early
+// when the revocation list cannot be validated: the server returns junk for
+// the list and 404 for everything else (including its signature), so the
+// "failed to load revocation list" wrapper must be returned.
+func TestArtifactVerify_InvalidRevocationList(t *testing.T) {
+	tempDir := t.TempDir()
+	artifactPath := filepath.Join(tempDir, "test.bin")
+	err := os.WriteFile(artifactPath, []byte("test"), 0644)
+	require.NoError(t, err)
+
+	// Setup server with invalid revocation list
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		switch r.URL.Path {
+		case "/artifact-signatures/keys/" + revocationFileName:
+			_, _ = w.Write([]byte("invalid data"))
+		default:
+			http.NotFound(w, r)
+		}
+	}))
+	defer server.Close()
+
+	rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+	require.NoError(t, err)
+
+	rootPubKey := PublicKey{
+		Key:      rootKey.Key.Public().(ed25519.PublicKey),
+		Metadata: rootKey.Metadata,
+	}
+
+	av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
+	require.NoError(t, err)
+
+	ctx := context.Background()
+	err = av.Verify(ctx, "1.0.0", artifactPath)
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "failed to load revocation list")
+}
+
+// TestArtifactVerify_MissingArtifactFile checks that Verify fails when the
+// local artifact file does not exist even though every remote object
+// (revocation list, key bundle, detached signature) is available.
+func TestArtifactVerify_MissingArtifactFile(t *testing.T) {
+	rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+	require.NoError(t, err)
+
+	rootPubKey := PublicKey{
+		Key:      rootKey.Key.Public().(ed25519.PublicKey),
+		Metadata: rootKey.Metadata,
+	}
+
+	// Create revocation list
+	revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
+	require.NoError(t, err)
+
+	artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+	require.NoError(t, err)
+
+	artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM)
+	require.NoError(t, err)
+
+	artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey})
+	require.NoError(t, err)
+
+	// Create signature for non-existent file
+	testData := []byte("test")
+	artifactSigData, err := SignData(*artifactKey, testData)
+	require.NoError(t, err)
+
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		switch r.URL.Path {
+		case "/artifact-signatures/keys/" + revocationFileName:
+			_, _ = w.Write(revocationData)
+		case "/artifact-signatures/keys/" + revocationSignFileName:
+			_, _ = w.Write(revocationSig)
+		case "/artifact-signatures/keys/" + artifactPubKeysFileName:
+			_, _ = w.Write(artifactKeysBundle)
+		case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
+			_, _ = w.Write(artifactKeysSig)
+		case "/artifact-signatures/tag/v1.0.0/missing.bin.sig":
+			_, _ = w.Write(artifactSigData)
+		default:
+			http.NotFound(w, r)
+		}
+	}))
+	defer server.Close()
+
+	av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
+	require.NoError(t, err)
+
+	ctx := context.Background()
+	// Was "file.bin", whose signature the server does NOT host — that made
+	// the test fail at the signature download instead of the missing local
+	// file. "missing.bin" matches the served .sig so the os.ReadFile step is
+	// the one that fails, which is what this test is named for.
+	err = av.Verify(ctx, "1.0.0", "missing.bin")
+	assert.Error(t, err)
+}
+
+// TestArtifactVerify_ServerUnavailable checks that Verify surfaces an error
+// when the keys repository cannot be reached at all. The 100ms context
+// timeout bounds the connection attempt; port 19999 is assumed unused.
+func TestArtifactVerify_ServerUnavailable(t *testing.T) {
+	tempDir := t.TempDir()
+	artifactPath := filepath.Join(tempDir, "test.bin")
+	err := os.WriteFile(artifactPath, []byte("test"), 0644)
+	require.NoError(t, err)
+
+	rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+	require.NoError(t, err)
+
+	rootPubKey := PublicKey{
+		Key:      rootKey.Key.Public().(ed25519.PublicKey),
+		Metadata: rootKey.Metadata,
+	}
+
+	// Use URL that doesn't exist
+	av, err := newArtifactVerify("http://localhost:19999/keys", []PublicKey{rootPubKey})
+	require.NoError(t, err)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+	defer cancel()
+
+	err = av.Verify(ctx, "1.0.0", artifactPath)
+	assert.Error(t, err)
+}
+
+// TestArtifactVerify_ContextCancellation checks that Verify honors context
+// deadlines: the server sleeps 500ms per request while the context expires
+// after 10ms, so the first download must fail with a context error.
+func TestArtifactVerify_ContextCancellation(t *testing.T) {
+	tempDir := t.TempDir()
+	artifactPath := filepath.Join(tempDir, "test.bin")
+	err := os.WriteFile(artifactPath, []byte("test"), 0644)
+	require.NoError(t, err)
+
+	// Create a server that delays response
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		time.Sleep(500 * time.Millisecond)
+		_, _ = w.Write([]byte("data"))
+	}))
+	defer server.Close()
+
+	rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
+	require.NoError(t, err)
+
+	rootPubKey := PublicKey{
+		Key:      rootKey.Key.Public().(ed25519.PublicKey),
+		Metadata: rootKey.Metadata,
+	}
+
+	av, err := newArtifactVerify(server.URL, []PublicKey{rootPubKey})
+	require.NoError(t, err)
+
+	// Create context that cancels quickly
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
+	defer cancel()
+
+	err = av.Verify(ctx, "1.0.0", artifactPath)
+	assert.Error(t, err)
+}
+
+// TestArtifactVerify_WithRevocation checks that an artifact signed by a
+// revoked key fails verification even though the key is still present in the
+// published key bundle — revocation must win over bundling.
+func TestArtifactVerify_WithRevocation(t *testing.T) {
+	tempDir := t.TempDir()
+
+	// Generate root key
+	rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
+	require.NoError(t, err)
+
+	// Generate two artifact keys
+	artifactKey1, _, artifactPubPEM1, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+	require.NoError(t, err)
+	artifactPubKey1, err := ParseArtifactPubKey(artifactPubPEM1)
+	require.NoError(t, err)
+
+	_, _, artifactPubPEM2, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+	require.NoError(t, err)
+	artifactPubKey2, err := ParseArtifactPubKey(artifactPubPEM2)
+	require.NoError(t, err)
+
+	// Create revocation list with first key revoked
+	emptyRevocation, _, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
+	require.NoError(t, err)
+
+	parsedRevocation, err := ParseRevocationList(emptyRevocation)
+	require.NoError(t, err)
+
+	revocationData, revocationSig, err := ExtendRevocationList(*rootKey, *parsedRevocation, artifactPubKey1.Metadata.ID, defaultRevocationListExpiration)
+	require.NoError(t, err)
+
+	// Bundle both keys
+	artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey1, artifactPubKey2})
+	require.NoError(t, err)
+
+	// Create artifact signed by revoked key
+	artifactPath := filepath.Join(tempDir, "test.bin")
+	artifactData := []byte("test data")
+	err = os.WriteFile(artifactPath, artifactData, 0644)
+	require.NoError(t, err)
+
+	artifactSigData, err := SignData(*artifactKey1, artifactData)
+	require.NoError(t, err)
+
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		switch r.URL.Path {
+		case "/artifact-signatures/keys/" + revocationFileName:
+			_, _ = w.Write(revocationData)
+		case "/artifact-signatures/keys/" + revocationSignFileName:
+			_, _ = w.Write(revocationSig)
+		case "/artifact-signatures/keys/" + artifactPubKeysFileName:
+			_, _ = w.Write(artifactKeysBundle)
+		case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
+			_, _ = w.Write(artifactKeysSig)
+		case "/artifact-signatures/tag/v1.0.0/test.bin.sig":
+			_, _ = w.Write(artifactSigData)
+		default:
+			http.NotFound(w, r)
+		}
+	}))
+	defer server.Close()
+
+	rootPubKey := PublicKey{
+		Key:      rootKey.Key.Public().(ed25519.PublicKey),
+		Metadata: rootKey.Metadata,
+	}
+
+	av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
+	require.NoError(t, err)
+
+	ctx := context.Background()
+	err = av.Verify(ctx, "1.0.0", artifactPath)
+	// Should fail because the signing key is revoked
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "no signing Key found")
+}
+
+// TestArtifactVerify_ValidWithSecondKey is the counterpart to
+// TestArtifactVerify_WithRevocation: revoking one key must not poison the
+// bundle — an artifact signed by the other, non-revoked key still verifies.
+func TestArtifactVerify_ValidWithSecondKey(t *testing.T) {
+	tempDir := t.TempDir()
+
+	// Generate root key
+	rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
+	require.NoError(t, err)
+
+	// Generate two artifact keys
+	_, _, artifactPubPEM1, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+	require.NoError(t, err)
+	artifactPubKey1, err := ParseArtifactPubKey(artifactPubPEM1)
+	require.NoError(t, err)
+
+	artifactKey2, _, artifactPubPEM2, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+	require.NoError(t, err)
+	artifactPubKey2, err := ParseArtifactPubKey(artifactPubPEM2)
+	require.NoError(t, err)
+
+	// Create revocation list with first key revoked
+	emptyRevocation, _, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
+	require.NoError(t, err)
+
+	parsedRevocation, err := ParseRevocationList(emptyRevocation)
+	require.NoError(t, err)
+
+	revocationData, revocationSig, err := ExtendRevocationList(*rootKey, *parsedRevocation, artifactPubKey1.Metadata.ID, defaultRevocationListExpiration)
+	require.NoError(t, err)
+
+	// Bundle both keys
+	artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey1, artifactPubKey2})
+	require.NoError(t, err)
+
+	// Create artifact signed by second key (not revoked)
+	artifactPath := filepath.Join(tempDir, "test.bin")
+	artifactData := []byte("test data")
+	err = os.WriteFile(artifactPath, artifactData, 0644)
+	require.NoError(t, err)
+
+	artifactSigData, err := SignData(*artifactKey2, artifactData)
+	require.NoError(t, err)
+
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		switch r.URL.Path {
+		case "/artifact-signatures/keys/" + revocationFileName:
+			_, _ = w.Write(revocationData)
+		case "/artifact-signatures/keys/" + revocationSignFileName:
+			_, _ = w.Write(revocationSig)
+		case "/artifact-signatures/keys/" + artifactPubKeysFileName:
+			_, _ = w.Write(artifactKeysBundle)
+		case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
+			_, _ = w.Write(artifactKeysSig)
+		case "/artifact-signatures/tag/v1.0.0/test.bin.sig":
+			_, _ = w.Write(artifactSigData)
+		default:
+			http.NotFound(w, r)
+		}
+	}))
+	defer server.Close()
+
+	rootPubKey := PublicKey{
+		Key:      rootKey.Key.Public().(ed25519.PublicKey),
+		Metadata: rootKey.Metadata,
+	}
+
+	av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
+	require.NoError(t, err)
+
+	ctx := context.Background()
+	err = av.Verify(ctx, "1.0.0", artifactPath)
+	// Should succeed because second key is not revoked
+	assert.NoError(t, err)
+}
+
+// TestArtifactVerify_TamperedArtifact checks integrity protection: the
+// published signature covers "original data" while the local file holds
+// "tampered data", so Verify must fail in the artifact-validation step.
+func TestArtifactVerify_TamperedArtifact(t *testing.T) {
+	tempDir := t.TempDir()
+
+	// Generate root key and artifact key
+	rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
+	require.NoError(t, err)
+
+	artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
+	require.NoError(t, err)
+	artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM)
+	require.NoError(t, err)
+
+	// Create revocation list
+	revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
+	require.NoError(t, err)
+
+	// Bundle keys
+	artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey})
+	require.NoError(t, err)
+
+	// Sign original data
+	originalData := []byte("original data")
+	artifactSigData, err := SignData(*artifactKey, originalData)
+	require.NoError(t, err)
+
+	// Write tampered data to file
+	artifactPath := filepath.Join(tempDir, "test.bin")
+	tamperedData := []byte("tampered data")
+	err = os.WriteFile(artifactPath, tamperedData, 0644)
+	require.NoError(t, err)
+
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		switch r.URL.Path {
+		case "/artifact-signatures/keys/" + revocationFileName:
+			_, _ = w.Write(revocationData)
+		case "/artifact-signatures/keys/" + revocationSignFileName:
+			_, _ = w.Write(revocationSig)
+		case "/artifact-signatures/keys/" + artifactPubKeysFileName:
+			_, _ = w.Write(artifactKeysBundle)
+		case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
+			_, _ = w.Write(artifactKeysSig)
+		case "/artifact-signatures/tag/v1.0.0/test.bin.sig":
+			_, _ = w.Write(artifactSigData)
+		default:
+			http.NotFound(w, r)
+		}
+	}))
+	defer server.Close()
+
+	rootPubKey := PublicKey{
+		Key:      rootKey.Key.Public().(ed25519.PublicKey),
+		Metadata: rootKey.Metadata,
+	}
+
+	av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
+	require.NoError(t, err)
+
+	ctx := context.Background()
+	err = av.Verify(ctx, "1.0.0", artifactPath)
+	// Should fail because artifact was tampered
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "failed to validate artifact")
+}
+
+// Test URL validation
+
+// TestArtifactVerify_URLParsing is a table-driven check of base-URL
+// validation in newArtifactVerify. Only URLs that net/url fails to parse
+// (e.g. "://invalid") are rejected; no network access is performed.
+func TestArtifactVerify_URLParsing(t *testing.T) {
+	tests := []struct {
+		name        string
+		keysBaseURL string
+		expectError bool
+	}{
+		{
+			name:        "Valid HTTP URL",
+			keysBaseURL: "http://example.com/artifact-signatures",
+			expectError: false,
+		},
+		{
+			name:        "Valid HTTPS URL",
+			keysBaseURL: "https://example.com/artifact-signatures",
+			expectError: false,
+		},
+		{
+			name:        "URL with port",
+			keysBaseURL: "http://localhost:8080/artifact-signatures",
+			expectError: false,
+		},
+		{
+			name:        "Invalid URL",
+			keysBaseURL: "://invalid",
+			expectError: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			_, err := newArtifactVerify(tt.keysBaseURL, nil)
+			if tt.expectError {
+				assert.Error(t, err)
+			} else {
+				assert.NoError(t, err)
+			}
+		})
+	}
+}
diff --git a/client/internal/updatemanager/update.go b/client/internal/updatemanager/update.go
new file mode 100644
index 000000000..875b50b49
--- /dev/null
+++ b/client/internal/updatemanager/update.go
@@ -0,0 +1,11 @@
+package updatemanager
+
+import v "github.com/hashicorp/go-version"
+
+// UpdateInterface abstracts the update manager so consumers can depend on an
+// interface (and tests can supply a mock) instead of the concrete type.
+// NOTE(review): method semantics below are inferred from the names — the
+// implementations are outside this file; confirm against them.
+type UpdateInterface interface {
+	// StopWatch stops watching for updates.
+	StopWatch()
+	// SetDaemonVersion records the running daemon version; presumably the
+	// bool reports whether the stored value changed.
+	SetDaemonVersion(newVersion string) bool
+	// SetOnUpdateListener registers a callback invoked when an update is seen.
+	SetOnUpdateListener(updateFn func())
+	// LatestVersion returns the most recently known release version.
+	LatestVersion() *v.Version
+	// StartFetcher begins fetching version information.
+	StartFetcher()
+}
diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go
index 463c93d57..e901386d9 100644
--- a/client/ios/NetBirdSDK/client.go
+++ b/client/ios/NetBirdSDK/client.go
@@ -75,6 +75,8 @@ type Client struct {
dnsManager dns.IosDnsManager
loginComplete bool
connectClient *internal.ConnectClient
+ // preloadedConfig holds config loaded from JSON (used on tvOS where file writes are blocked)
+ preloadedConfig *profilemanager.Config
}
// NewClient instantiate a new Client
@@ -92,17 +94,44 @@ func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName s
}
}
+// SetConfigFromJSON loads config from a JSON string into memory.
+// This is used on tvOS where file writes to App Group containers are blocked.
+// When set, IsLoginRequired() and Run() will use this preloaded config instead of reading from file.
+// The config is only parsed and stored; nothing is written to disk here, and
+// on a parse error the previously preloaded config (if any) is left untouched.
+func (c *Client) SetConfigFromJSON(jsonStr string) error {
+	cfg, err := profilemanager.ConfigFromJSON(jsonStr)
+	if err != nil {
+		log.Errorf("SetConfigFromJSON: failed to parse config JSON: %v", err)
+		return err
+	}
+	c.preloadedConfig = cfg
+	log.Infof("SetConfigFromJSON: config loaded successfully from JSON")
+	return nil
+}
+
// Run start the internal client. It is a blocker function
func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
exportEnvList(envList)
log.Infof("Starting NetBird client")
log.Debugf("Tunnel uses interface: %s", interfaceName)
- cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
- ConfigPath: c.cfgFile,
- StateFilePath: c.stateFile,
- })
- if err != nil {
- return err
+
+ var cfg *profilemanager.Config
+ var err error
+
+ // Use preloaded config if available (tvOS where file writes are blocked)
+ if c.preloadedConfig != nil {
+ log.Infof("Run: using preloaded config from memory")
+ cfg = c.preloadedConfig
+ } else {
+ log.Infof("Run: loading config from file")
+ // Use DirectUpdateOrCreateConfig to avoid atomic file operations (temp file + rename)
+ // which are blocked by the tvOS sandbox in App Group containers
+ cfg, err = profilemanager.DirectUpdateOrCreateConfig(profilemanager.ConfigInput{
+ ConfigPath: c.cfgFile,
+ StateFilePath: c.stateFile,
+ })
+ if err != nil {
+ return err
+ }
}
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
c.recorder.UpdateRosenpass(cfg.RosenpassEnabled, cfg.RosenpassPermissive)
@@ -120,7 +149,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
c.ctxCancelLock.Unlock()
auth := NewAuthWithConfig(ctx, cfg)
- err = auth.Login()
+ err = auth.LoginSync()
if err != nil {
return err
}
@@ -131,7 +160,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
c.onHostDnsFn = func([]string) {}
cfg.WgIface = interfaceName
- c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
+ c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
}
@@ -208,14 +237,45 @@ func (c *Client) IsLoginRequired() bool {
defer c.ctxCancelLock.Unlock()
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
- cfg, _ := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
- ConfigPath: c.cfgFile,
- })
+ var cfg *profilemanager.Config
+ var err error
- needsLogin, _ := internal.IsLoginRequired(ctx, cfg)
+ // Use preloaded config if available (tvOS where file writes are blocked)
+ if c.preloadedConfig != nil {
+ log.Infof("IsLoginRequired: using preloaded config from memory")
+ cfg = c.preloadedConfig
+ } else {
+ log.Infof("IsLoginRequired: loading config from file")
+ // Use DirectUpdateOrCreateConfig to avoid atomic file operations (temp file + rename)
+ // which are blocked by the tvOS sandbox in App Group containers
+ cfg, err = profilemanager.DirectUpdateOrCreateConfig(profilemanager.ConfigInput{
+ ConfigPath: c.cfgFile,
+ })
+ if err != nil {
+ log.Errorf("IsLoginRequired: failed to load config: %v", err)
+ // If we can't load config, assume login is required
+ return true
+ }
+ }
+
+ if cfg == nil {
+ log.Errorf("IsLoginRequired: config is nil")
+ return true
+ }
+
+ needsLogin, err := internal.IsLoginRequired(ctx, cfg)
+ if err != nil {
+ log.Errorf("IsLoginRequired: check failed: %v", err)
+ // If the check fails, assume login is required to be safe
+ return true
+ }
+ log.Infof("IsLoginRequired: needsLogin=%v", needsLogin)
return needsLogin
}
+// loginForMobileAuthTimeout is the timeout for requesting auth info from the server
+const loginForMobileAuthTimeout = 30 * time.Second
+
func (c *Client) LoginForMobile() string {
var ctx context.Context
//nolint
@@ -228,16 +288,26 @@ func (c *Client) LoginForMobile() string {
defer c.ctxCancelLock.Unlock()
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
- cfg, _ := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
+ // Use DirectUpdateOrCreateConfig to avoid atomic file operations (temp file + rename)
+ // which are blocked by the tvOS sandbox in App Group containers
+ cfg, err := profilemanager.DirectUpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
})
+ if err != nil {
+ log.Errorf("LoginForMobile: failed to load config: %v", err)
+ return fmt.Sprintf("failed to load config: %v", err)
+ }
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, false, "")
if err != nil {
return err.Error()
}
- flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
+ // Use a bounded timeout for the auth info request to prevent indefinite hangs
+ authInfoCtx, authInfoCancel := context.WithTimeout(ctx, loginForMobileAuthTimeout)
+ defer authInfoCancel()
+
+ flowInfo, err := oAuthFlow.RequestAuthInfo(authInfoCtx)
if err != nil {
return err.Error()
}
@@ -249,10 +319,14 @@ func (c *Client) LoginForMobile() string {
defer cancel()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
if err != nil {
+ log.Errorf("LoginForMobile: WaitToken failed: %v", err)
return
}
jwtToken := tokenInfo.GetTokenToUse()
- _ = internal.Login(ctx, cfg, "", jwtToken)
+ if err := internal.Login(ctx, cfg, "", jwtToken); err != nil {
+ log.Errorf("LoginForMobile: Login failed: %v", err)
+ return
+ }
c.loginComplete = true
}()
diff --git a/client/ios/NetBirdSDK/login.go b/client/ios/NetBirdSDK/login.go
index 1c2b38a61..27fdcf5ef 100644
--- a/client/ios/NetBirdSDK/login.go
+++ b/client/ios/NetBirdSDK/login.go
@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/cmd"
"github.com/netbirdio/netbird/client/internal"
+ "github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
)
@@ -33,7 +34,8 @@ type ErrListener interface {
// URLOpener it is a callback interface. The Open function will be triggered if
// the backend want to show an url for the user
type URLOpener interface {
- Open(string)
+ Open(url string, userCode string)
+ OnLoginSuccess()
}
// Auth can register or login new client
@@ -72,13 +74,32 @@ func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth
// SaveConfigIfSSOSupported test the connectivity with the management server by retrieving the server device flow info.
// If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO
// is not supported and returns false without saving the configuration. For other errors return false.
-func (a *Auth) SaveConfigIfSSOSupported() (bool, error) {
+func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
+ if listener == nil {
+ log.Errorf("SaveConfigIfSSOSupported: listener is nil")
+ return
+ }
+ go func() {
+ sso, err := a.saveConfigIfSSOSupported()
+ if err != nil {
+ listener.OnError(err)
+ } else {
+ listener.OnSuccess(sso)
+ }
+ }()
+}
+
+func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
- _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
+ _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
- _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
- if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
+ _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
+ s, ok := gstatus.FromError(err)
+ if !ok {
+ return err
+ }
+ if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
supportsSSO = false
err = nil
}
@@ -97,12 +118,29 @@ func (a *Auth) SaveConfigIfSSOSupported() (bool, error) {
return false, fmt.Errorf("backoff cycle failed: %v", err)
}
- err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
+ // Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
+ // which are blocked by the tvOS sandbox in App Group containers
+ err = profilemanager.DirectWriteOutConfig(a.cfgPath, a.config)
return true, err
}
// LoginWithSetupKeyAndSaveConfig test the connectivity with the management server with the setup key.
-func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
+func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupKey string, deviceName string) {
+ if resultListener == nil {
+ log.Errorf("LoginWithSetupKeyAndSaveConfig: resultListener is nil")
+ return
+ }
+ go func() {
+ err := a.loginWithSetupKeyAndSaveConfig(setupKey, deviceName)
+ if err != nil {
+ resultListener.OnError(err)
+ } else {
+ resultListener.OnSuccess()
+ }
+ }()
+}
+
+func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
//nolint
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
@@ -118,10 +156,14 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
return fmt.Errorf("backoff cycle failed: %v", err)
}
- return profilemanager.WriteOutConfig(a.cfgPath, a.config)
+ // Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
+ // which are blocked by the tvOS sandbox in App Group containers
+ return profilemanager.DirectWriteOutConfig(a.cfgPath, a.config)
}
-func (a *Auth) Login() error {
+// LoginSync performs a synchronous login check without UI interaction
+// Used for background VPN connections where the user should already be authenticated
+func (a *Auth) LoginSync() error {
var needsLogin bool
// check if we need to generate JWT token
@@ -135,23 +177,142 @@ func (a *Auth) Login() error {
jwtToken := ""
if needsLogin {
- return fmt.Errorf("Not authenticated")
+ return fmt.Errorf("not authenticated")
}
err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)
- if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
- return nil
+ if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
+ // PermissionDenied means registration is required or peer is blocked
+ return backoff.Permanent(err)
}
return err
})
+ if err != nil {
+ return fmt.Errorf("login failed: %v", err)
+ }
+
+ return nil
+}
+
+// Login performs interactive login with device authentication support
+// Deprecated: Use LoginWithDeviceName instead to ensure proper device naming on tvOS
+func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, forceDeviceAuth bool) {
+ // Use empty device name - system will use hostname as fallback
+ a.LoginWithDeviceName(resultListener, urlOpener, forceDeviceAuth, "")
+}
+
+// LoginWithDeviceName performs interactive login with device authentication support
+// The deviceName parameter allows specifying a custom device name (required for tvOS)
+func (a *Auth) LoginWithDeviceName(resultListener ErrListener, urlOpener URLOpener, forceDeviceAuth bool, deviceName string) {
+ if resultListener == nil {
+ log.Errorf("LoginWithDeviceName: resultListener is nil")
+ return
+ }
+ if urlOpener == nil {
+ log.Errorf("LoginWithDeviceName: urlOpener is nil")
+ resultListener.OnError(fmt.Errorf("urlOpener is nil"))
+ return
+ }
+ go func() {
+ err := a.login(urlOpener, forceDeviceAuth, deviceName)
+ if err != nil {
+ resultListener.OnError(err)
+ } else {
+ resultListener.OnSuccess()
+ }
+ }()
+}
+
+func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName string) error {
+ var needsLogin bool
+
+ // Create context with device name if provided
+ ctx := a.ctx
+ if deviceName != "" {
+ //nolint:staticcheck
+ ctx = context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
+ }
+
+ // check if we need to generate JWT token
+ err := a.withBackOff(ctx, func() (err error) {
+ needsLogin, err = internal.IsLoginRequired(ctx, a.config)
+ return
+ })
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
}
+ jwtToken := ""
+ if needsLogin {
+ tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, forceDeviceAuth)
+ if err != nil {
+ return fmt.Errorf("interactive sso login failed: %v", err)
+ }
+ jwtToken = tokenInfo.GetTokenToUse()
+ }
+
+ err = a.withBackOff(ctx, func() error {
+ err := internal.Login(ctx, a.config, "", jwtToken)
+ if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
+ // PermissionDenied means registration is required or peer is blocked
+ return backoff.Permanent(err)
+ }
+ return err
+ })
+ if err != nil {
+ return fmt.Errorf("login failed: %v", err)
+ }
+
+ // Save the config before notifying success to ensure persistence completes
+ // before the callback potentially triggers teardown on the Swift side.
+ // Note: This differs from Android which doesn't save config after login.
+ // On iOS/tvOS, we save here because:
+ // 1. The config may have been modified during login (e.g., new tokens)
+ // 2. On tvOS, the Network Extension context may be the only place with
+ // write permissions to the App Group container
+ if a.cfgPath != "" {
+ if err := profilemanager.DirectWriteOutConfig(a.cfgPath, a.config); err != nil {
+ log.Warnf("failed to save config after login: %v", err)
+ }
+ }
+
+ // Notify caller of successful login synchronously before returning
+ urlOpener.OnLoginSuccess()
+
return nil
}
+const authInfoRequestTimeout = 30 * time.Second
+
+func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) {
+ oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, forceDeviceAuth, "")
+ if err != nil {
+ return nil, err
+ }
+
+ // Use a bounded timeout for the auth info request to prevent indefinite hangs
+ authInfoCtx, authInfoCancel := context.WithTimeout(a.ctx, authInfoRequestTimeout)
+ defer authInfoCancel()
+
+ flowInfo, err := oAuthFlow.RequestAuthInfo(authInfoCtx)
+ if err != nil {
+ return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
+ }
+
+ urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
+
+ waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
+ waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
+ defer cancel()
+ tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
+ if err != nil {
+ return nil, fmt.Errorf("waiting for browser login failed: %v", err)
+ }
+
+ return &tokenInfo, nil
+}
+
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
return backoff.RetryNotify(
bf,
@@ -160,3 +321,24 @@ func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
})
}
+
+// GetConfigJSON returns the current config as a JSON string.
+// This can be used by the caller to persist the config via alternative storage
+// mechanisms (e.g., UserDefaults on tvOS where file writes are blocked).
+func (a *Auth) GetConfigJSON() (string, error) {
+ if a.config == nil {
+ return "", fmt.Errorf("no config available")
+ }
+ return profilemanager.ConfigToJSON(a.config)
+}
+
+// SetConfigFromJSON loads config from a JSON string.
+// This can be used to restore config from alternative storage mechanisms.
+func (a *Auth) SetConfigFromJSON(jsonStr string) error {
+ cfg, err := profilemanager.ConfigFromJSON(jsonStr)
+ if err != nil {
+ return err
+ }
+ a.config = cfg
+ return nil
+}
diff --git a/client/ios/NetBirdSDK/preferences.go b/client/ios/NetBirdSDK/preferences.go
index 39ae06538..c26a6decd 100644
--- a/client/ios/NetBirdSDK/preferences.go
+++ b/client/ios/NetBirdSDK/preferences.go
@@ -112,6 +112,8 @@ func (p *Preferences) GetRosenpassPermissive() (bool, error) {
// Commit write out the changes into config file
func (p *Preferences) Commit() error {
- _, err := profilemanager.UpdateOrCreateConfig(p.configInput)
+ // Use DirectUpdateOrCreateConfig to avoid atomic file operations (temp file + rename)
+ // which are blocked by the tvOS sandbox in App Group containers
+ _, err := profilemanager.DirectUpdateOrCreateConfig(p.configInput)
return err
}
diff --git a/client/netbird.wxs b/client/netbird.wxs
index ba827debf..03221dd91 100644
--- a/client/netbird.wxs
+++ b/client/netbird.wxs
@@ -51,7 +51,7 @@
-
+
diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go
index 28e8b2d4e..5d56befc7 100644
--- a/client/proto/daemon.pb.go
+++ b/client/proto/daemon.pb.go
@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.6
-// protoc v6.32.1
+// protoc v3.21.12
// source: daemon.proto
package proto
@@ -893,6 +893,7 @@ type UpRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
ProfileName *string `protobuf:"bytes,1,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"`
Username *string `protobuf:"bytes,2,opt,name=username,proto3,oneof" json:"username,omitempty"`
+ AutoUpdate *bool `protobuf:"varint,3,opt,name=autoUpdate,proto3,oneof" json:"autoUpdate,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -941,6 +942,13 @@ func (x *UpRequest) GetUsername() string {
return ""
}
+func (x *UpRequest) GetAutoUpdate() bool {
+ if x != nil && x.AutoUpdate != nil {
+ return *x.AutoUpdate
+ }
+ return false
+}
+
type UpResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
@@ -2005,6 +2013,7 @@ type SSHSessionInfo struct {
RemoteAddress string `protobuf:"bytes,2,opt,name=remoteAddress,proto3" json:"remoteAddress,omitempty"`
Command string `protobuf:"bytes,3,opt,name=command,proto3" json:"command,omitempty"`
JwtUsername string `protobuf:"bytes,4,opt,name=jwtUsername,proto3" json:"jwtUsername,omitempty"`
+ PortForwards []string `protobuf:"bytes,5,rep,name=portForwards,proto3" json:"portForwards,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -2067,6 +2076,13 @@ func (x *SSHSessionInfo) GetJwtUsername() string {
return ""
}
+func (x *SSHSessionInfo) GetPortForwards() []string {
+ if x != nil {
+ return x.PortForwards
+ }
+ return nil
+}
+
// SSHServerState contains the latest state of the SSH server
type SSHServerState struct {
state protoimpl.MessageState `protogen:"open.v1"`
@@ -5356,6 +5372,94 @@ func (x *WaitJWTTokenResponse) GetExpiresIn() int64 {
return 0
}
+type InstallerResultRequest struct {
+ state protoimpl.MessageState `protogen:"open.v1"`
+ unknownFields protoimpl.UnknownFields
+ sizeCache protoimpl.SizeCache
+}
+
+func (x *InstallerResultRequest) Reset() {
+ *x = InstallerResultRequest{}
+ mi := &file_daemon_proto_msgTypes[79]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+}
+
+func (x *InstallerResultRequest) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*InstallerResultRequest) ProtoMessage() {}
+
+func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message {
+ mi := &file_daemon_proto_msgTypes[79]
+ if x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use InstallerResultRequest.ProtoReflect.Descriptor instead.
+func (*InstallerResultRequest) Descriptor() ([]byte, []int) {
+ return file_daemon_proto_rawDescGZIP(), []int{79}
+}
+
+type InstallerResultResponse struct {
+ state protoimpl.MessageState `protogen:"open.v1"`
+ Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"`
+ ErrorMsg string `protobuf:"bytes,2,opt,name=errorMsg,proto3" json:"errorMsg,omitempty"`
+ unknownFields protoimpl.UnknownFields
+ sizeCache protoimpl.SizeCache
+}
+
+func (x *InstallerResultResponse) Reset() {
+ *x = InstallerResultResponse{}
+ mi := &file_daemon_proto_msgTypes[80]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+}
+
+func (x *InstallerResultResponse) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*InstallerResultResponse) ProtoMessage() {}
+
+func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message {
+ mi := &file_daemon_proto_msgTypes[80]
+ if x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use InstallerResultResponse.ProtoReflect.Descriptor instead.
+func (*InstallerResultResponse) Descriptor() ([]byte, []int) {
+ return file_daemon_proto_rawDescGZIP(), []int{80}
+}
+
+func (x *InstallerResultResponse) GetSuccess() bool {
+ if x != nil {
+ return x.Success
+ }
+ return false
+}
+
+func (x *InstallerResultResponse) GetErrorMsg() string {
+ if x != nil {
+ return x.ErrorMsg
+ }
+ return ""
+}
+
type PortInfo_Range struct {
state protoimpl.MessageState `protogen:"open.v1"`
Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
@@ -5366,7 +5470,7 @@ type PortInfo_Range struct {
func (x *PortInfo_Range) Reset() {
*x = PortInfo_Range{}
- mi := &file_daemon_proto_msgTypes[80]
+ mi := &file_daemon_proto_msgTypes[82]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -5378,7 +5482,7 @@ func (x *PortInfo_Range) String() string {
func (*PortInfo_Range) ProtoMessage() {}
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
- mi := &file_daemon_proto_msgTypes[80]
+ mi := &file_daemon_proto_msgTypes[82]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -5502,12 +5606,16 @@ const file_daemon_proto_rawDesc = "" +
"\buserCode\x18\x01 \x01(\tR\buserCode\x12\x1a\n" +
"\bhostname\x18\x02 \x01(\tR\bhostname\",\n" +
"\x14WaitSSOLoginResponse\x12\x14\n" +
- "\x05email\x18\x01 \x01(\tR\x05email\"p\n" +
+ "\x05email\x18\x01 \x01(\tR\x05email\"\xa4\x01\n" +
"\tUpRequest\x12%\n" +
"\vprofileName\x18\x01 \x01(\tH\x00R\vprofileName\x88\x01\x01\x12\x1f\n" +
- "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" +
+ "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01\x12#\n" +
+ "\n" +
+ "autoUpdate\x18\x03 \x01(\bH\x02R\n" +
+ "autoUpdate\x88\x01\x01B\x0e\n" +
"\f_profileNameB\v\n" +
- "\t_username\"\f\n" +
+ "\t_usernameB\r\n" +
+ "\v_autoUpdate\"\f\n" +
"\n" +
"UpResponse\"\xa1\x01\n" +
"\rStatusRequest\x12,\n" +
@@ -5606,12 +5714,13 @@ const file_daemon_proto_rawDesc = "" +
"\aservers\x18\x01 \x03(\tR\aservers\x12\x18\n" +
"\adomains\x18\x02 \x03(\tR\adomains\x12\x18\n" +
"\aenabled\x18\x03 \x01(\bR\aenabled\x12\x14\n" +
- "\x05error\x18\x04 \x01(\tR\x05error\"\x8e\x01\n" +
+ "\x05error\x18\x04 \x01(\tR\x05error\"\xb2\x01\n" +
"\x0eSSHSessionInfo\x12\x1a\n" +
"\busername\x18\x01 \x01(\tR\busername\x12$\n" +
"\rremoteAddress\x18\x02 \x01(\tR\rremoteAddress\x12\x18\n" +
"\acommand\x18\x03 \x01(\tR\acommand\x12 \n" +
- "\vjwtUsername\x18\x04 \x01(\tR\vjwtUsername\"^\n" +
+ "\vjwtUsername\x18\x04 \x01(\tR\vjwtUsername\x12\"\n" +
+ "\fportForwards\x18\x05 \x03(\tR\fportForwards\"^\n" +
"\x0eSSHServerState\x12\x18\n" +
"\aenabled\x18\x01 \x01(\bR\aenabled\x122\n" +
"\bsessions\x18\x02 \x03(\v2\x16.daemon.SSHSessionInfoR\bsessions\"\xaf\x04\n" +
@@ -5893,7 +6002,11 @@ const file_daemon_proto_rawDesc = "" +
"\x14WaitJWTTokenResponse\x12\x14\n" +
"\x05token\x18\x01 \x01(\tR\x05token\x12\x1c\n" +
"\ttokenType\x18\x02 \x01(\tR\ttokenType\x12\x1c\n" +
- "\texpiresIn\x18\x03 \x01(\x03R\texpiresIn*b\n" +
+ "\texpiresIn\x18\x03 \x01(\x03R\texpiresIn\"\x18\n" +
+ "\x16InstallerResultRequest\"O\n" +
+ "\x17InstallerResultResponse\x12\x18\n" +
+ "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" +
+ "\berrorMsg\x18\x02 \x01(\tR\berrorMsg*b\n" +
"\bLogLevel\x12\v\n" +
"\aUNKNOWN\x10\x00\x12\t\n" +
"\x05PANIC\x10\x01\x12\t\n" +
@@ -5902,7 +6015,7 @@ const file_daemon_proto_rawDesc = "" +
"\x04WARN\x10\x04\x12\b\n" +
"\x04INFO\x10\x05\x12\t\n" +
"\x05DEBUG\x10\x06\x12\t\n" +
- "\x05TRACE\x10\a2\xdb\x12\n" +
+ "\x05TRACE\x10\a2\xb4\x13\n" +
"\rDaemonService\x126\n" +
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
@@ -5938,7 +6051,8 @@ const file_daemon_proto_rawDesc = "" +
"\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" +
"\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" +
"\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12N\n" +
- "\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00B\bZ\x06/protob\x06proto3"
+ "\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" +
+ "\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00B\bZ\x06/protob\x06proto3"
var (
file_daemon_proto_rawDescOnce sync.Once
@@ -5953,7 +6067,7 @@ func file_daemon_proto_rawDescGZIP() []byte {
}
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4)
-var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 82)
+var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 84)
var file_daemon_proto_goTypes = []any{
(LogLevel)(0), // 0: daemon.LogLevel
(OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType
@@ -6038,19 +6152,21 @@ var file_daemon_proto_goTypes = []any{
(*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse
(*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest
(*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse
- nil, // 83: daemon.Network.ResolvedIPsEntry
- (*PortInfo_Range)(nil), // 84: daemon.PortInfo.Range
- nil, // 85: daemon.SystemEvent.MetadataEntry
- (*durationpb.Duration)(nil), // 86: google.protobuf.Duration
- (*timestamppb.Timestamp)(nil), // 87: google.protobuf.Timestamp
+ (*InstallerResultRequest)(nil), // 83: daemon.InstallerResultRequest
+ (*InstallerResultResponse)(nil), // 84: daemon.InstallerResultResponse
+ nil, // 85: daemon.Network.ResolvedIPsEntry
+ (*PortInfo_Range)(nil), // 86: daemon.PortInfo.Range
+ nil, // 87: daemon.SystemEvent.MetadataEntry
+ (*durationpb.Duration)(nil), // 88: google.protobuf.Duration
+ (*timestamppb.Timestamp)(nil), // 89: google.protobuf.Timestamp
}
var file_daemon_proto_depIdxs = []int32{
1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType
- 86, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
+ 88, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
- 87, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
- 87, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
- 86, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
+ 89, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
+ 89, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
+ 88, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState
@@ -6061,8 +6177,8 @@ var file_daemon_proto_depIdxs = []int32{
57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent
26, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
- 83, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
- 84, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
+ 85, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
+ 86, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
@@ -6073,10 +6189,10 @@ var file_daemon_proto_depIdxs = []int32{
54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
- 87, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
- 85, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
+ 89, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
+ 87, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
- 86, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
+ 88, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
@@ -6111,40 +6227,42 @@ var file_daemon_proto_depIdxs = []int32{
79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
5, // 65: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
- 8, // 66: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
- 10, // 67: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
- 12, // 68: daemon.DaemonService.Up:output_type -> daemon.UpResponse
- 14, // 69: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
- 16, // 70: daemon.DaemonService.Down:output_type -> daemon.DownResponse
- 18, // 71: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
- 29, // 72: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
- 31, // 73: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
- 31, // 74: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
- 36, // 75: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
- 38, // 76: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
- 40, // 77: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
- 42, // 78: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
- 45, // 79: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
- 47, // 80: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
- 49, // 81: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
- 51, // 82: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
- 55, // 83: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
- 57, // 84: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
- 59, // 85: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
- 61, // 86: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
- 63, // 87: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
- 65, // 88: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
- 67, // 89: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
- 69, // 90: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
- 72, // 91: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
- 74, // 92: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
- 76, // 93: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
- 78, // 94: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
- 80, // 95: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
- 82, // 96: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
- 6, // 97: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
- 66, // [66:98] is the sub-list for method output_type
- 34, // [34:66] is the sub-list for method input_type
+ 83, // 66: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
+ 8, // 67: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
+ 10, // 68: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
+ 12, // 69: daemon.DaemonService.Up:output_type -> daemon.UpResponse
+ 14, // 70: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
+ 16, // 71: daemon.DaemonService.Down:output_type -> daemon.DownResponse
+ 18, // 72: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
+ 29, // 73: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
+ 31, // 74: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
+ 31, // 75: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
+ 36, // 76: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
+ 38, // 77: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
+ 40, // 78: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
+ 42, // 79: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
+ 45, // 80: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
+ 47, // 81: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
+ 49, // 82: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
+ 51, // 83: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
+ 55, // 84: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
+ 57, // 85: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
+ 59, // 86: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
+ 61, // 87: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
+ 63, // 88: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
+ 65, // 89: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
+ 67, // 90: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
+ 69, // 91: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
+ 72, // 92: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
+ 74, // 93: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
+ 76, // 94: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
+ 78, // 95: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
+ 80, // 96: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
+ 82, // 97: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
+ 6, // 98: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
+ 84, // 99: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
+ 67, // [67:100] is the sub-list for method output_type
+ 34, // [34:67] is the sub-list for method input_type
34, // [34:34] is the sub-list for extension type_name
34, // [34:34] is the sub-list for extension extendee
0, // [0:34] is the sub-list for field type_name
@@ -6174,7 +6292,7 @@ func file_daemon_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
NumEnums: 4,
- NumMessages: 82,
+ NumMessages: 84,
NumExtensions: 0,
NumServices: 1,
},
diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto
index 3dfd3da8d..b75ca821a 100644
--- a/client/proto/daemon.proto
+++ b/client/proto/daemon.proto
@@ -95,6 +95,8 @@ service DaemonService {
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
+
+ rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}
}
@@ -215,6 +217,7 @@ message WaitSSOLoginResponse {
message UpRequest {
optional string profileName = 1;
optional string username = 2;
+ optional bool autoUpdate = 3;
}
message UpResponse {}
@@ -369,6 +372,7 @@ message SSHSessionInfo {
string remoteAddress = 2;
string command = 3;
string jwtUsername = 4;
+ repeated string portForwards = 5;
}
// SSHServerState contains the latest state of the SSH server
@@ -772,3 +776,11 @@ message WaitJWTTokenResponse {
// expiration time in seconds
int64 expiresIn = 3;
}
+
+message InstallerResultRequest {
+}
+
+message InstallerResultResponse {
+ bool success = 1;
+ string errorMsg = 2;
+}
diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go
index 6b01309b7..fdabb1879 100644
--- a/client/proto/daemon_grpc.pb.go
+++ b/client/proto/daemon_grpc.pb.go
@@ -71,6 +71,7 @@ type DaemonServiceClient interface {
// WaitJWTToken waits for JWT authentication completion
WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error)
+ GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error)
}
type daemonServiceClient struct {
@@ -392,6 +393,15 @@ func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifec
return out, nil
}
+func (c *daemonServiceClient) GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error) {
+ out := new(InstallerResultResponse)
+ err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetInstallerResult", in, out, opts...)
+ if err != nil {
+ return nil, err
+ }
+ return out, nil
+}
+
// DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility
@@ -449,6 +459,7 @@ type DaemonServiceServer interface {
// WaitJWTToken waits for JWT authentication completion
WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error)
+ GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error)
mustEmbedUnimplementedDaemonServiceServer()
}
@@ -552,6 +563,9 @@ func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTo
func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented")
}
+func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) {
+ return nil, status.Errorf(codes.Unimplemented, "method GetInstallerResult not implemented")
+}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -1144,6 +1158,24 @@ func _DaemonService_NotifyOSLifecycle_Handler(srv interface{}, ctx context.Conte
return interceptor(ctx, in, info, handler)
}
+func _DaemonService_GetInstallerResult_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+ in := new(InstallerResultRequest)
+ if err := dec(in); err != nil {
+ return nil, err
+ }
+ if interceptor == nil {
+ return srv.(DaemonServiceServer).GetInstallerResult(ctx, in)
+ }
+ info := &grpc.UnaryServerInfo{
+ Server: srv,
+ FullMethod: "/daemon.DaemonService/GetInstallerResult",
+ }
+ handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+ return srv.(DaemonServiceServer).GetInstallerResult(ctx, req.(*InstallerResultRequest))
+ }
+ return interceptor(ctx, in, info, handler)
+}
+
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -1275,6 +1307,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "NotifyOSLifecycle",
Handler: _DaemonService_NotifyOSLifecycle_Handler,
},
+ {
+ MethodName: "GetInstallerResult",
+ Handler: _DaemonService_GetInstallerResult_Handler,
+ },
},
Streams: []grpc.StreamDesc{
{
diff --git a/client/proto/generate.sh b/client/proto/generate.sh
index f9a2c3750..e659cef90 100755
--- a/client/proto/generate.sh
+++ b/client/proto/generate.sh
@@ -14,4 +14,4 @@ cd "$script_path"
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.36.6
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
protoc -I ./ ./daemon.proto --go_out=../ --go-grpc_out=../ --experimental_allow_proto3_optional
-cd "$old_pwd"
\ No newline at end of file
+cd "$old_pwd"
diff --git a/client/server/server.go b/client/server/server.go
index d33595115..7b6c4e98c 100644
--- a/client/server/server.go
+++ b/client/server/server.go
@@ -145,10 +145,10 @@ func (s *Server) Start() error {
ctx, cancel := context.WithCancel(s.rootCtx)
s.actCancel = cancel
- // set the default config if not exists
- if err := s.setDefaultConfigIfNotExists(ctx); err != nil {
- log.Errorf("failed to set default config: %v", err)
- return fmt.Errorf("failed to set default config: %w", err)
+ // copy old default config
+ _, err = s.profileManager.CopyDefaultProfileIfNotExists()
+ if err != nil && !errors.Is(err, profilemanager.ErrorOldDefaultConfigNotFound) {
+ return err
}
activeProf, err := s.profileManager.GetActiveProfileState()
@@ -156,23 +156,11 @@ func (s *Server) Start() error {
return fmt.Errorf("failed to get active profile state: %w", err)
}
- config, err := s.getConfig(activeProf)
+ config, existingConfig, err := s.getConfig(activeProf)
if err != nil {
log.Errorf("failed to get active profile config: %v", err)
- if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{
- Name: "default",
- Username: "",
- }); err != nil {
- log.Errorf("failed to set active profile state: %v", err)
- return fmt.Errorf("failed to set active profile state: %w", err)
- }
-
- config, err = profilemanager.GetConfig(s.profileManager.DefaultProfilePath())
- if err != nil {
- log.Errorf("failed to get default profile config: %v", err)
- return fmt.Errorf("failed to get default profile config: %w", err)
- }
+ return err
}
s.config = config
@@ -186,44 +174,27 @@ func (s *Server) Start() error {
}
if config.DisableAutoConnect {
+ state.Set(internal.StatusIdle)
+ return nil
+ }
+
+ if !existingConfig {
+ log.Warnf("not trying to connect when configuration was just created")
+ state.Set(internal.StatusNeedsLogin)
return nil
}
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
s.clientGiveUpChan = make(chan struct{})
- go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
- return nil
-}
-
-func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error {
- ok, err := s.profileManager.CopyDefaultProfileIfNotExists()
- if err != nil {
- if err := s.profileManager.CreateDefaultProfile(); err != nil {
- log.Errorf("failed to create default profile: %v", err)
- return fmt.Errorf("failed to create default profile: %w", err)
- }
-
- if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{
- Name: "default",
- Username: "",
- }); err != nil {
- log.Errorf("failed to set active profile state: %v", err)
- return fmt.Errorf("failed to set active profile state: %w", err)
- }
- }
- if ok {
- state := internal.CtxGetState(ctx)
- state.Set(internal.StatusNeedsLogin)
- }
-
+ go s.connectWithRetryRuns(ctx, config, s.statusRecorder, false, s.clientRunningChan, s.clientGiveUpChan)
return nil
}
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
// mechanism to keep the client connected even when the connection is lost.
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
-func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
+func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}, giveUpChan chan struct{}) {
defer func() {
s.mutex.Lock()
s.clientRunning = false
@@ -231,7 +202,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
}()
if s.config.DisableAutoConnect {
- if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil {
+ if err := s.connect(ctx, s.config, s.statusRecorder, doInitialAutoUpdate, runningChan); err != nil {
log.Debugf("run client connection exited with error: %v", err)
}
log.Tracef("client connection exited")
@@ -260,7 +231,8 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
}()
runOperation := func() error {
- err := s.connect(ctx, profileConfig, statusRecorder, runningChan)
+ err := s.connect(ctx, profileConfig, statusRecorder, doInitialAutoUpdate, runningChan)
+ doInitialAutoUpdate = false
if err != nil {
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
return err
@@ -486,7 +458,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
s.mutex.Unlock()
- config, err := s.getConfig(activeProf)
+ config, _, err := s.getConfig(activeProf)
if err != nil {
log.Errorf("failed to get active profile config: %v", err)
return nil, fmt.Errorf("failed to get active profile config: %w", err)
@@ -715,7 +687,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
log.Infof("active profile: %s for %s", activeProf.Name, activeProf.Username)
- config, err := s.getConfig(activeProf)
+ config, _, err := s.getConfig(activeProf)
if err != nil {
log.Errorf("failed to get active profile config: %v", err)
return nil, fmt.Errorf("failed to get active profile config: %w", err)
@@ -728,7 +700,12 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
s.clientGiveUpChan = make(chan struct{})
- go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
+
+ var doAutoUpdate bool
+ if msg != nil && msg.AutoUpdate != nil && *msg.AutoUpdate {
+ doAutoUpdate = true
+ }
+ go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, doAutoUpdate, s.clientRunningChan, s.clientGiveUpChan)
return s.waitForUp(callerCtx)
}
@@ -805,7 +782,7 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
log.Errorf("failed to get active profile state: %v", err)
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
- config, err := s.getConfig(activeProf)
+ config, _, err := s.getConfig(activeProf)
if err != nil {
log.Errorf("failed to get default profile config: %v", err)
return nil, fmt.Errorf("failed to get default profile config: %w", err)
@@ -902,7 +879,7 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed to get active profile state: %v", err)
}
- config, err := s.getConfig(activeProf)
+ config, _, err := s.getConfig(activeProf)
if err != nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "not logged in")
}
@@ -926,19 +903,24 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe
return &proto.LogoutResponse{}, nil
}
-// getConfig loads the config from the active profile
-func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*profilemanager.Config, error) {
+// getConfig reads the config file for the active profile and returns the Config along with whether the config file already existed. Errors out if it does not exist.
+func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*profilemanager.Config, bool, error) {
cfgPath, err := activeProf.FilePath()
if err != nil {
- return nil, fmt.Errorf("failed to get active profile file path: %w", err)
+ return nil, false, fmt.Errorf("failed to get active profile file path: %w", err)
}
- config, err := profilemanager.GetConfig(cfgPath)
+ _, err = os.Stat(cfgPath)
+ configExisted := !os.IsNotExist(err)
+
+ log.Infof("active profile config existed: %t, err %v", configExisted, err)
+
+ config, err := profilemanager.ReadConfig(cfgPath)
if err != nil {
- return nil, fmt.Errorf("failed to get config: %w", err)
+ return nil, false, fmt.Errorf("failed to get config: %w", err)
}
- return config, nil
+ return config, configExisted, nil
}
func (s *Server) canRemoveProfile(profileName string) error {
@@ -1122,6 +1104,7 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
RemoteAddress: session.RemoteAddress,
Command: session.Command,
JwtUsername: session.JWTUsername,
+ PortForwards: session.PortForwards,
})
}
@@ -1539,9 +1522,9 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest)
return features, nil
}
-func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error {
+func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}) error {
log.Tracef("running client connection")
- s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
+ s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, doInitialAutoUpdate)
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
if err := s.connectClient.Run(runningChan); err != nil {
return err
diff --git a/client/server/server_test.go b/client/server/server_test.go
index 5f28a2664..1ed115769 100644
--- a/client/server/server_test.go
+++ b/client/server/server_test.go
@@ -112,7 +112,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1")
- s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
+ s.connectWithRetryRuns(ctx, config, s.statusRecorder, false, nil, nil)
if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter)
}
@@ -326,7 +326,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
if err != nil {
return nil, "", err
}
- mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController)
+ mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
return nil, "", err
}
diff --git a/client/server/updateresult.go b/client/server/updateresult.go
new file mode 100644
index 000000000..8e00d5062
--- /dev/null
+++ b/client/server/updateresult.go
@@ -0,0 +1,30 @@
+package server
+
+import (
+ "context"
+
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/client/internal/updatemanager/installer"
+ "github.com/netbirdio/netbird/client/proto"
+)
+
+func (s *Server) GetInstallerResult(ctx context.Context, _ *proto.InstallerResultRequest) (*proto.InstallerResultResponse, error) {
+ inst := installer.New()
+ dir := inst.TempDir()
+
+ rh := installer.NewResultHandler(dir)
+ result, err := rh.Watch(ctx)
+ if err != nil {
+ log.Errorf("failed to watch update result: %v", err)
+ return &proto.InstallerResultResponse{
+ Success: false,
+ ErrorMsg: err.Error(),
+ }, nil
+ }
+
+ return &proto.InstallerResultResponse{
+ Success: result.Success,
+ ErrorMsg: result.Error,
+ }, nil
+}
diff --git a/client/ssh/auth/auth.go b/client/ssh/auth/auth.go
new file mode 100644
index 000000000..079282fdc
--- /dev/null
+++ b/client/ssh/auth/auth.go
@@ -0,0 +1,177 @@
+package auth
+
+import (
+ "errors"
+ "fmt"
+ "sync"
+
+ log "github.com/sirupsen/logrus"
+
+ sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
+)
+
+const (
+ // DefaultUserIDClaim is the default JWT claim used to extract user IDs
+ DefaultUserIDClaim = "sub"
+ // Wildcard is a special user ID that matches all users
+ Wildcard = "*"
+)
+
+var (
+ ErrEmptyUserID = errors.New("JWT user ID is empty")
+ ErrUserNotAuthorized = errors.New("user is not authorized to access this peer")
+ ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user")
+ ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user")
+)
+
+// Authorizer handles SSH fine-grained access control authorization
+type Authorizer struct {
+ // UserIDClaim is the JWT claim to extract the user ID from
+ userIDClaim string
+
+ // authorizedUsers is a list of hashed user IDs authorized to access this peer
+ authorizedUsers []sshuserhash.UserIDHash
+
+ // machineUsers maps OS login usernames to lists of authorized user indexes
+ machineUsers map[string][]uint32
+
+ // mu protects the list of users
+ mu sync.RWMutex
+}
+
+// Config contains configuration for the SSH authorizer
+type Config struct {
+ // UserIDClaim is the JWT claim to extract the user ID from (e.g., "sub", "email")
+ UserIDClaim string
+
+ // AuthorizedUsers is a list of hashed user IDs (FNV-1a 64-bit) authorized to access this peer
+ AuthorizedUsers []sshuserhash.UserIDHash
+
+ // MachineUsers maps OS login usernames to indexes in AuthorizedUsers
+ // If a user wants to login as a specific OS user, their index must be in the corresponding list
+ MachineUsers map[string][]uint32
+}
+
+// NewAuthorizer creates a new SSH authorizer with empty configuration
+func NewAuthorizer() *Authorizer {
+ a := &Authorizer{
+ userIDClaim: DefaultUserIDClaim,
+ machineUsers: make(map[string][]uint32),
+ }
+
+ return a
+}
+
+// Update updates the authorizer configuration with new values
+func (a *Authorizer) Update(config *Config) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ if config == nil {
+ // Clear authorization
+ a.userIDClaim = DefaultUserIDClaim
+ a.authorizedUsers = []sshuserhash.UserIDHash{}
+ a.machineUsers = make(map[string][]uint32)
+ log.Info("SSH authorization cleared")
+ return
+ }
+
+ userIDClaim := config.UserIDClaim
+ if userIDClaim == "" {
+ userIDClaim = DefaultUserIDClaim
+ }
+ a.userIDClaim = userIDClaim
+
+ // Store authorized users list
+ a.authorizedUsers = config.AuthorizedUsers
+
+ // Store machine users mapping
+ machineUsers := make(map[string][]uint32)
+ for osUser, indexes := range config.MachineUsers {
+ if len(indexes) > 0 {
+ machineUsers[osUser] = indexes
+ }
+ }
+ a.machineUsers = machineUsers
+
+ log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings",
+ len(config.AuthorizedUsers), len(machineUsers))
+}
+
+// Authorize validates if a user is authorized to login as the specified OS user.
+// Returns a success message describing how authorization was granted, or an error.
+func (a *Authorizer) Authorize(jwtUserID, osUsername string) (string, error) {
+ if jwtUserID == "" {
+ return "", fmt.Errorf("JWT user ID is empty for OS user %q: %w", osUsername, ErrEmptyUserID)
+ }
+
+ // Hash the JWT user ID for comparison
+ hashedUserID, err := sshuserhash.HashUserID(jwtUserID)
+ if err != nil {
+ return "", fmt.Errorf("hash user ID %q for OS user %q: %w", jwtUserID, osUsername, err)
+ }
+
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+
+ // Find the index of this user in the authorized list
+ userIndex, found := a.findUserIndex(hashedUserID)
+ if !found {
+ return "", fmt.Errorf("user %q (hash: %s) not in authorized list for OS user %q: %w", jwtUserID, hashedUserID, osUsername, ErrUserNotAuthorized)
+ }
+
+ return a.checkMachineUserMapping(jwtUserID, osUsername, userIndex)
+}
+
+// checkMachineUserMapping validates if a user's index is authorized for the specified OS user
+// Checks wildcard mapping first, then specific OS user mappings
+func (a *Authorizer) checkMachineUserMapping(jwtUserID, osUsername string, userIndex int) (string, error) {
+ // If wildcard exists and user's index is in the wildcard list, allow access to any OS user
+ if wildcardIndexes, hasWildcard := a.machineUsers[Wildcard]; hasWildcard {
+ if a.isIndexInList(uint32(userIndex), wildcardIndexes) {
+ return fmt.Sprintf("granted via wildcard (index: %d)", userIndex), nil
+ }
+ }
+
+ // Check for specific OS username mapping
+ allowedIndexes, hasMachineUserMapping := a.machineUsers[osUsername]
+ if !hasMachineUserMapping {
+ // No mapping for this OS user - deny by default (fail closed)
+ return "", fmt.Errorf("no machine user mapping for OS user %q (JWT user: %s): %w", osUsername, jwtUserID, ErrNoMachineUserMapping)
+ }
+
+ // Check if user's index is in the allowed indexes for this specific OS user
+ if !a.isIndexInList(uint32(userIndex), allowedIndexes) {
+ return "", fmt.Errorf("user %q not mapped to OS user %q (index: %d): %w", jwtUserID, osUsername, userIndex, ErrUserNotMappedToOSUser)
+ }
+
+ return fmt.Sprintf("granted (index: %d)", userIndex), nil
+}
+
+// GetUserIDClaim returns the JWT claim name used to extract user IDs
+func (a *Authorizer) GetUserIDClaim() string {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+ return a.userIDClaim
+}
+
+// findUserIndex finds the index of a hashed user ID in the authorized users list
+// Returns the index and true if found, 0 and false if not found
+func (a *Authorizer) findUserIndex(hashedUserID sshuserhash.UserIDHash) (int, bool) {
+ for i, id := range a.authorizedUsers {
+ if id == hashedUserID {
+ return i, true
+ }
+ }
+ return 0, false
+}
+
+// isIndexInList checks if an index exists in a list of indexes
+func (a *Authorizer) isIndexInList(index uint32, indexes []uint32) bool {
+ for _, idx := range indexes {
+ if idx == index {
+ return true
+ }
+ }
+ return false
+}
diff --git a/client/ssh/auth/auth_test.go b/client/ssh/auth/auth_test.go
new file mode 100644
index 000000000..fa27b72e8
--- /dev/null
+++ b/client/ssh/auth/auth_test.go
@@ -0,0 +1,612 @@
+package auth
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/netbirdio/netbird/shared/sshauth"
+)
+
+func TestAuthorizer_Authorize_UserNotInList(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ // Set up authorized users list with one user
+ authorizedUserHash, err := sshauth.HashUserID("authorized-user")
+ require.NoError(t, err)
+
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: []sshauth.UserIDHash{authorizedUserHash},
+ MachineUsers: map[string][]uint32{},
+ }
+ authorizer.Update(config)
+
+ // Try to authorize a different user
+ _, err = authorizer.Authorize("unauthorized-user", "root")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrUserNotAuthorized)
+}
+
+func TestAuthorizer_Authorize_UserInList_NoMachineUserRestrictions(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ user1Hash, err := sshauth.HashUserID("user1")
+ require.NoError(t, err)
+ user2Hash, err := sshauth.HashUserID("user2")
+ require.NoError(t, err)
+
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
+ MachineUsers: map[string][]uint32{}, // Empty = deny all (fail closed)
+ }
+ authorizer.Update(config)
+
+ // All attempts should fail when no machine user mappings exist (fail closed)
+ _, err = authorizer.Authorize("user1", "root")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrNoMachineUserMapping)
+
+ _, err = authorizer.Authorize("user2", "admin")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrNoMachineUserMapping)
+
+ _, err = authorizer.Authorize("user1", "postgres")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrNoMachineUserMapping)
+}
+
+func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Allowed(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ user1Hash, err := sshauth.HashUserID("user1")
+ require.NoError(t, err)
+ user2Hash, err := sshauth.HashUserID("user2")
+ require.NoError(t, err)
+ user3Hash, err := sshauth.HashUserID("user3")
+ require.NoError(t, err)
+
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
+ MachineUsers: map[string][]uint32{
+ "root": {0, 1}, // user1 and user2 can access root
+ "postgres": {1, 2}, // user2 and user3 can access postgres
+ "admin": {0}, // only user1 can access admin
+ },
+ }
+ authorizer.Update(config)
+
+ // user1 (index 0) should access root and admin
+ _, err = authorizer.Authorize("user1", "root")
+ assert.NoError(t, err)
+
+ _, err = authorizer.Authorize("user1", "admin")
+ assert.NoError(t, err)
+
+ // user2 (index 1) should access root and postgres
+ _, err = authorizer.Authorize("user2", "root")
+ assert.NoError(t, err)
+
+ _, err = authorizer.Authorize("user2", "postgres")
+ assert.NoError(t, err)
+
+ // user3 (index 2) should access postgres
+ _, err = authorizer.Authorize("user3", "postgres")
+ assert.NoError(t, err)
+}
+
+func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Denied(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ // Set up authorized users list
+ user1Hash, err := sshauth.HashUserID("user1")
+ require.NoError(t, err)
+ user2Hash, err := sshauth.HashUserID("user2")
+ require.NoError(t, err)
+ user3Hash, err := sshauth.HashUserID("user3")
+ require.NoError(t, err)
+
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
+ MachineUsers: map[string][]uint32{
+ "root": {0, 1}, // user1 and user2 can access root
+ "postgres": {1, 2}, // user2 and user3 can access postgres
+ "admin": {0}, // only user1 can access admin
+ },
+ }
+ authorizer.Update(config)
+
+ // user1 (index 0) should NOT access postgres
+ _, err = authorizer.Authorize("user1", "postgres")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
+
+ // user2 (index 1) should NOT access admin
+ _, err = authorizer.Authorize("user2", "admin")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
+
+ // user3 (index 2) should NOT access root
+ _, err = authorizer.Authorize("user3", "root")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
+
+ // user3 (index 2) should NOT access admin
+ _, err = authorizer.Authorize("user3", "admin")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
+}
+
+func TestAuthorizer_Authorize_UserInList_OSUserNotInMapping(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ // Set up authorized users list
+ user1Hash, err := sshauth.HashUserID("user1")
+ require.NoError(t, err)
+
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
+ MachineUsers: map[string][]uint32{
+ "root": {0}, // only root is mapped
+ },
+ }
+ authorizer.Update(config)
+
+ // user1 should NOT access an unmapped OS user (fail closed)
+ _, err = authorizer.Authorize("user1", "postgres")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrNoMachineUserMapping)
+}
+
+func TestAuthorizer_Authorize_EmptyJWTUserID(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ // Set up authorized users list
+ user1Hash, err := sshauth.HashUserID("user1")
+ require.NoError(t, err)
+
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
+ MachineUsers: map[string][]uint32{},
+ }
+ authorizer.Update(config)
+
+ // Empty user ID should fail
+ _, err = authorizer.Authorize("", "root")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrEmptyUserID)
+}
+
+func TestAuthorizer_Authorize_MultipleUsersInList(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ // Set up multiple authorized users
+ userHashes := make([]sshauth.UserIDHash, 10)
+ for i := 0; i < 10; i++ {
+ hash, err := sshauth.HashUserID("user" + string(rune('0'+i)))
+ require.NoError(t, err)
+ userHashes[i] = hash
+ }
+
+ // Create machine user mapping for all users
+ rootIndexes := make([]uint32, 10)
+ for i := 0; i < 10; i++ {
+ rootIndexes[i] = uint32(i)
+ }
+
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: userHashes,
+ MachineUsers: map[string][]uint32{
+ "root": rootIndexes,
+ },
+ }
+ authorizer.Update(config)
+
+ // All users should be authorized for root
+ for i := 0; i < 10; i++ {
+ _, err := authorizer.Authorize("user"+string(rune('0'+i)), "root")
+ assert.NoError(t, err, "user%d should be authorized", i)
+ }
+
+ // User not in list should fail
+ _, err := authorizer.Authorize("unknown-user", "root")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrUserNotAuthorized)
+}
+
+func TestAuthorizer_Update_ClearsConfiguration(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ // Set up initial configuration
+ user1Hash, err := sshauth.HashUserID("user1")
+ require.NoError(t, err)
+
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
+ MachineUsers: map[string][]uint32{"root": {0}},
+ }
+ authorizer.Update(config)
+
+ // user1 should be authorized
+ _, err = authorizer.Authorize("user1", "root")
+ assert.NoError(t, err)
+
+ // Clear configuration
+ authorizer.Update(nil)
+
+ // user1 should no longer be authorized
+ _, err = authorizer.Authorize("user1", "root")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrUserNotAuthorized)
+}
+
+func TestAuthorizer_Update_EmptyMachineUsersListEntries(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ user1Hash, err := sshauth.HashUserID("user1")
+ require.NoError(t, err)
+
+ // Machine users with empty index lists should be filtered out
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
+ MachineUsers: map[string][]uint32{
+ "root": {0},
+ "postgres": {}, // empty list - should be filtered out
+ "admin": nil, // nil list - should be filtered out
+ },
+ }
+ authorizer.Update(config)
+
+ // root should work
+ _, err = authorizer.Authorize("user1", "root")
+ assert.NoError(t, err)
+
+ // postgres should fail (no mapping)
+ _, err = authorizer.Authorize("user1", "postgres")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrNoMachineUserMapping)
+
+ // admin should fail (no mapping)
+ _, err = authorizer.Authorize("user1", "admin")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrNoMachineUserMapping)
+}
+
+func TestAuthorizer_CustomUserIDClaim(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ // Set up with custom user ID claim
+ user1Hash, err := sshauth.HashUserID("user@example.com")
+ require.NoError(t, err)
+
+ config := &Config{
+ UserIDClaim: "email",
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
+ MachineUsers: map[string][]uint32{
+ "root": {0},
+ },
+ }
+ authorizer.Update(config)
+
+ // Verify the custom claim is set
+ assert.Equal(t, "email", authorizer.GetUserIDClaim())
+
+ // Authorize with email as user ID
+ _, err = authorizer.Authorize("user@example.com", "root")
+ assert.NoError(t, err)
+}
+
+func TestAuthorizer_DefaultUserIDClaim(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ // Verify default claim
+ assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim())
+ assert.Equal(t, "sub", authorizer.GetUserIDClaim())
+
+ // Set up with empty user ID claim (should use default)
+ user1Hash, err := sshauth.HashUserID("user1")
+ require.NoError(t, err)
+
+ config := &Config{
+ UserIDClaim: "", // empty - should use default
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
+ MachineUsers: map[string][]uint32{},
+ }
+ authorizer.Update(config)
+
+ // Should fall back to default
+ assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim())
+}
+
+func TestAuthorizer_MachineUserMapping_LargeIndexes(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ // Create a large authorized users list
+ const numUsers = 1000
+ userHashes := make([]sshauth.UserIDHash, numUsers)
+ for i := 0; i < numUsers; i++ {
+ hash, err := sshauth.HashUserID("user" + string(rune(i)))
+ require.NoError(t, err)
+ userHashes[i] = hash
+ }
+
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: userHashes,
+ MachineUsers: map[string][]uint32{
+ "root": {0, 500, 999}, // first, middle, and last user
+ },
+ }
+ authorizer.Update(config)
+
+ // First user should have access
+ _, err := authorizer.Authorize("user"+string(rune(0)), "root")
+ assert.NoError(t, err)
+
+ // Middle user should have access
+ _, err = authorizer.Authorize("user"+string(rune(500)), "root")
+ assert.NoError(t, err)
+
+ // Last user should have access
+ _, err = authorizer.Authorize("user"+string(rune(999)), "root")
+ assert.NoError(t, err)
+
+ // User not in mapping should NOT have access
+ _, err = authorizer.Authorize("user"+string(rune(100)), "root")
+ assert.Error(t, err)
+}
+
+func TestAuthorizer_ConcurrentAuthorization(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ // Set up authorized users
+ user1Hash, err := sshauth.HashUserID("user1")
+ require.NoError(t, err)
+ user2Hash, err := sshauth.HashUserID("user2")
+ require.NoError(t, err)
+
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
+ MachineUsers: map[string][]uint32{
+ "root": {0, 1},
+ },
+ }
+ authorizer.Update(config)
+
+ // Test concurrent authorization calls (should be safe to read concurrently)
+ const numGoroutines = 100
+ errChan := make(chan error, numGoroutines)
+
+ for i := 0; i < numGoroutines; i++ {
+ go func(idx int) {
+ user := "user1"
+ if idx%2 == 0 {
+ user = "user2"
+ }
+ _, err := authorizer.Authorize(user, "root")
+ errChan <- err
+ }(i)
+ }
+
+ // Wait for all goroutines to complete and collect errors
+ for i := 0; i < numGoroutines; i++ {
+ err := <-errChan
+ assert.NoError(t, err)
+ }
+}
+
+func TestAuthorizer_Wildcard_AllowsAllAuthorizedUsers(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ user1Hash, err := sshauth.HashUserID("user1")
+ require.NoError(t, err)
+ user2Hash, err := sshauth.HashUserID("user2")
+ require.NoError(t, err)
+ user3Hash, err := sshauth.HashUserID("user3")
+ require.NoError(t, err)
+
+ // Configure with wildcard - all authorized users can access any OS user
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
+ MachineUsers: map[string][]uint32{
+ "*": {0, 1, 2}, // wildcard with all user indexes
+ },
+ }
+ authorizer.Update(config)
+
+ // All authorized users should be able to access any OS user
+ _, err = authorizer.Authorize("user1", "root")
+ assert.NoError(t, err)
+
+ _, err = authorizer.Authorize("user2", "postgres")
+ assert.NoError(t, err)
+
+ _, err = authorizer.Authorize("user3", "admin")
+ assert.NoError(t, err)
+
+ _, err = authorizer.Authorize("user1", "ubuntu")
+ assert.NoError(t, err)
+
+ _, err = authorizer.Authorize("user2", "nginx")
+ assert.NoError(t, err)
+
+ _, err = authorizer.Authorize("user3", "docker")
+ assert.NoError(t, err)
+}
+
+func TestAuthorizer_Wildcard_UnauthorizedUserStillDenied(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ user1Hash, err := sshauth.HashUserID("user1")
+ require.NoError(t, err)
+
+ // Configure with wildcard
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
+ MachineUsers: map[string][]uint32{
+ "*": {0},
+ },
+ }
+ authorizer.Update(config)
+
+ // user1 should have access
+ _, err = authorizer.Authorize("user1", "root")
+ assert.NoError(t, err)
+
+ // Unauthorized user should still be denied even with wildcard
+ _, err = authorizer.Authorize("unauthorized-user", "root")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrUserNotAuthorized)
+}
+
+func TestAuthorizer_Wildcard_TakesPrecedenceOverSpecificMappings(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ user1Hash, err := sshauth.HashUserID("user1")
+ require.NoError(t, err)
+ user2Hash, err := sshauth.HashUserID("user2")
+ require.NoError(t, err)
+
+ // Configure with both wildcard and specific mappings
+ // Wildcard takes precedence for users in the wildcard index list
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
+ MachineUsers: map[string][]uint32{
+ "*": {0, 1}, // wildcard for both users
+ "root": {0}, // specific mapping that would normally restrict to user1 only
+ },
+ }
+ authorizer.Update(config)
+
+ // Both users should be able to access root via wildcard (takes precedence over specific mapping)
+ _, err = authorizer.Authorize("user1", "root")
+ assert.NoError(t, err)
+
+ _, err = authorizer.Authorize("user2", "root")
+ assert.NoError(t, err)
+
+ // Both users should be able to access any other OS user via wildcard
+ _, err = authorizer.Authorize("user1", "postgres")
+ assert.NoError(t, err)
+
+ _, err = authorizer.Authorize("user2", "admin")
+ assert.NoError(t, err)
+}
+
+func TestAuthorizer_NoWildcard_SpecificMappingsOnly(t *testing.T) {
+ authorizer := NewAuthorizer()
+
+ user1Hash, err := sshauth.HashUserID("user1")
+ require.NoError(t, err)
+ user2Hash, err := sshauth.HashUserID("user2")
+ require.NoError(t, err)
+
+ // Configure WITHOUT wildcard - only specific mappings
+ config := &Config{
+ UserIDClaim: DefaultUserIDClaim,
+ AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
+ MachineUsers: map[string][]uint32{
+ "root": {0}, // only user1
+ "postgres": {1}, // only user2
+ },
+ }
+ authorizer.Update(config)
+
+ // user1 can access root
+ _, err = authorizer.Authorize("user1", "root")
+ assert.NoError(t, err)
+
+ // user2 can access postgres
+ _, err = authorizer.Authorize("user2", "postgres")
+ assert.NoError(t, err)
+
+ // user1 cannot access postgres
+ _, err = authorizer.Authorize("user1", "postgres")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
+
+ // user2 cannot access root
+ _, err = authorizer.Authorize("user2", "root")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
+
+ // Neither can access unmapped OS users
+ _, err = authorizer.Authorize("user1", "admin")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrNoMachineUserMapping)
+
+ _, err = authorizer.Authorize("user2", "admin")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrNoMachineUserMapping)
+}
+
+func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) {
+ // This test covers the scenario where wildcard exists with limited indexes.
+ // Only users whose indexes are in the wildcard list can access any OS user via wildcard.
+ // Other users can only access OS users they are explicitly mapped to.
+ authorizer := NewAuthorizer()
+
+ // Create two authorized user hashes (simulating the base64-encoded hashes in the config)
+ wasmHash, err := sshauth.HashUserID("wasm")
+ require.NoError(t, err)
+ user2Hash, err := sshauth.HashUserID("user2")
+ require.NoError(t, err)
+
+ // Configure with wildcard having only index 0, and specific mappings for other OS users
+ config := &Config{
+ UserIDClaim: "sub",
+ AuthorizedUsers: []sshauth.UserIDHash{wasmHash, user2Hash},
+ MachineUsers: map[string][]uint32{
+ "*": {0}, // wildcard with only index 0 - only wasm has wildcard access
+ "alice": {1}, // specific mapping for user2
+ "bob": {1}, // specific mapping for user2
+ },
+ }
+ authorizer.Update(config)
+
+ // wasm (index 0) should access any OS user via wildcard
+ _, err = authorizer.Authorize("wasm", "root")
+ assert.NoError(t, err, "wasm should access root via wildcard")
+
+ _, err = authorizer.Authorize("wasm", "alice")
+ assert.NoError(t, err, "wasm should access alice via wildcard")
+
+ _, err = authorizer.Authorize("wasm", "bob")
+ assert.NoError(t, err, "wasm should access bob via wildcard")
+
+ _, err = authorizer.Authorize("wasm", "postgres")
+ assert.NoError(t, err, "wasm should access postgres via wildcard")
+
+ // user2 (index 1) should only access alice and bob (explicitly mapped), NOT root or postgres
+ _, err = authorizer.Authorize("user2", "alice")
+ assert.NoError(t, err, "user2 should access alice via explicit mapping")
+
+ _, err = authorizer.Authorize("user2", "bob")
+ assert.NoError(t, err, "user2 should access bob via explicit mapping")
+
+ _, err = authorizer.Authorize("user2", "root")
+ assert.Error(t, err, "user2 should NOT access root (not in wildcard indexes)")
+ assert.ErrorIs(t, err, ErrNoMachineUserMapping)
+
+ _, err = authorizer.Authorize("user2", "postgres")
+ assert.Error(t, err, "user2 should NOT access postgres (not explicitly mapped)")
+ assert.ErrorIs(t, err, ErrNoMachineUserMapping)
+
+ // Unauthorized user should still be denied
+ _, err = authorizer.Authorize("user3", "root")
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied")
+}
diff --git a/client/ssh/client/client.go b/client/ssh/client/client.go
index aab222093..342da7303 100644
--- a/client/ssh/client/client.go
+++ b/client/ssh/client/client.go
@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
- "io"
"net"
"os"
"path/filepath"
@@ -551,14 +550,15 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
defer func() {
if err := localConn.Close(); err != nil {
- log.Debugf("local connection close error: %v", err)
+ log.Debugf("local port forwarding: close local connection: %v", err)
}
}()
channel, err := c.client.Dial("tcp", remoteAddr)
if err != nil {
- if strings.Contains(err.Error(), "administratively prohibited") {
- _, _ = fmt.Fprintf(os.Stderr, "channel open failed: administratively prohibited: port forwarding is disabled\n")
+ var openErr *ssh.OpenChannelError
+ if errors.As(err, &openErr) && openErr.Reason == ssh.Prohibited {
+ _, _ = fmt.Fprintf(os.Stderr, "channel open failed: port forwarding is disabled\n")
} else {
log.Debugf("local port forwarding to %s failed: %v", remoteAddr, err)
}
@@ -566,19 +566,11 @@ func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
}
defer func() {
if err := channel.Close(); err != nil {
- log.Debugf("remote channel close error: %v", err)
+ log.Debugf("local port forwarding: close remote channel: %v", err)
}
}()
- go func() {
- if _, err := io.Copy(channel, localConn); err != nil {
- log.Debugf("local forward copy error (local->remote): %v", err)
- }
- }()
-
- if _, err := io.Copy(localConn, channel); err != nil {
- log.Debugf("local forward copy error (remote->local): %v", err)
- }
+ nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
}
// RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr
@@ -633,7 +625,7 @@ func (c *Client) sendTCPIPForwardRequest(req tcpipForwardMsg) error {
return fmt.Errorf("send tcpip-forward request: %w", err)
}
if !ok {
- return fmt.Errorf("remote port forwarding denied by server (check if --allow-ssh-remote-port-forwarding is enabled)")
+ return fmt.Errorf("remote port forwarding denied by server")
}
return nil
}
@@ -676,7 +668,7 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st
}
defer func() {
if err := channel.Close(); err != nil {
- log.Debugf("remote channel close error: %v", err)
+ log.Debugf("remote port forwarding: close remote channel: %v", err)
}
}()
@@ -688,19 +680,11 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st
}
defer func() {
if err := localConn.Close(); err != nil {
- log.Debugf("local connection close error: %v", err)
+ log.Debugf("remote port forwarding: close local connection: %v", err)
}
}()
- go func() {
- if _, err := io.Copy(localConn, channel); err != nil {
- log.Debugf("remote forward copy error (remote->local): %v", err)
- }
- }()
-
- if _, err := io.Copy(channel, localConn); err != nil {
- log.Debugf("remote forward copy error (local->remote): %v", err)
- }
+ nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
}
// tcpipForwardMsg represents the structure for tcpip-forward requests
diff --git a/client/ssh/common.go b/client/ssh/common.go
index 6574437b5..f6aec5f9c 100644
--- a/client/ssh/common.go
+++ b/client/ssh/common.go
@@ -193,3 +193,64 @@ func buildAddressList(hostname string, remote net.Addr) []string {
}
return addresses
}
+
+// BidirectionalCopy copies data bidirectionally between two io.ReadWriter connections.
+// It waits for both directions to complete before returning.
+// The caller is responsible for closing the connections.
+func BidirectionalCopy(logger *log.Entry, rw1, rw2 io.ReadWriter) {
+ done := make(chan struct{}, 2)
+
+ go func() {
+ if _, err := io.Copy(rw2, rw1); err != nil && !isExpectedCopyError(err) {
+ logger.Debugf("copy error (1->2): %v", err)
+ }
+ done <- struct{}{}
+ }()
+
+ go func() {
+ if _, err := io.Copy(rw1, rw2); err != nil && !isExpectedCopyError(err) {
+ logger.Debugf("copy error (2->1): %v", err)
+ }
+ done <- struct{}{}
+ }()
+
+ <-done
+ <-done
+}
+
+func isExpectedCopyError(err error) bool {
+ return errors.Is(err, io.EOF) || errors.Is(err, context.Canceled)
+}
+
+// BidirectionalCopyWithContext copies data bidirectionally between two io.ReadWriteCloser connections.
+// It waits for both directions to complete or for context cancellation before returning.
+// Both connections are closed when the function returns.
+func BidirectionalCopyWithContext(logger *log.Entry, ctx context.Context, conn1, conn2 io.ReadWriteCloser) {
+ done := make(chan struct{}, 2)
+
+ go func() {
+ if _, err := io.Copy(conn2, conn1); err != nil && !isExpectedCopyError(err) {
+ logger.Debugf("copy error (1->2): %v", err)
+ }
+ done <- struct{}{}
+ }()
+
+ go func() {
+ if _, err := io.Copy(conn1, conn2); err != nil && !isExpectedCopyError(err) {
+ logger.Debugf("copy error (2->1): %v", err)
+ }
+ done <- struct{}{}
+ }()
+
+ select {
+ case <-ctx.Done():
+ case <-done:
+ select {
+ case <-ctx.Done():
+ case <-done:
+ }
+ }
+
+ _ = conn1.Close()
+ _ = conn2.Close()
+}
diff --git a/client/ssh/proxy/proxy.go b/client/ssh/proxy/proxy.go
index 4e807e33c..cb1c36e13 100644
--- a/client/ssh/proxy/proxy.go
+++ b/client/ssh/proxy/proxy.go
@@ -2,6 +2,7 @@ package proxy
import (
"context"
+ "encoding/binary"
"errors"
"fmt"
"io"
@@ -42,6 +43,14 @@ type SSHProxy struct {
conn *grpc.ClientConn
daemonClient proto.DaemonServiceClient
browserOpener func(string) error
+
+ mu sync.RWMutex
+ backendClient *cryptossh.Client
+ // jwtToken is set once in runProxySSHServer before any handlers are called,
+ // so concurrent access is safe without additional synchronization.
+ jwtToken string
+
+ forwardedChannelsOnce sync.Once
}
func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browserOpener func(string) error) (*SSHProxy, error) {
@@ -63,6 +72,17 @@ func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browse
}
func (p *SSHProxy) Close() error {
+ p.mu.Lock()
+ backendClient := p.backendClient
+ p.backendClient = nil
+ p.mu.Unlock()
+
+ if backendClient != nil {
+ if err := backendClient.Close(); err != nil {
+ log.Debugf("close backend client: %v", err)
+ }
+ }
+
if p.conn != nil {
return p.conn.Close()
}
@@ -77,16 +97,16 @@ func (p *SSHProxy) Connect(ctx context.Context) error {
return fmt.Errorf(jwtAuthErrorMsg, err)
}
- return p.runProxySSHServer(ctx, jwtToken)
+ log.Debugf("JWT authentication successful, starting proxy to %s:%d", p.targetHost, p.targetPort)
+ return p.runProxySSHServer(jwtToken)
}
-func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error {
+func (p *SSHProxy) runProxySSHServer(jwtToken string) error {
+ p.jwtToken = jwtToken
serverVersion := fmt.Sprintf("%s-%s", detection.ProxyIdentifier, version.NetbirdVersion())
sshServer := &ssh.Server{
- Handler: func(s ssh.Session) {
- p.handleSSHSession(ctx, s, jwtToken)
- },
+ Handler: p.handleSSHSession,
ChannelHandlers: map[string]ssh.ChannelHandler{
"session": ssh.DefaultSessionHandler,
"direct-tcpip": p.directTCPIPHandler,
@@ -119,15 +139,20 @@ func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error
return nil
}
-func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jwtToken string) {
- targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
+func (p *SSHProxy) handleSSHSession(session ssh.Session) {
+ ptyReq, winCh, isPty := session.Pty()
+ hasCommand := len(session.Command()) > 0
- sshClient, err := p.dialBackend(ctx, targetAddr, session.User(), jwtToken)
+ sshClient, err := p.getOrCreateBackendClient(session.Context(), session.User())
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "SSH connection to NetBird server failed: %v\n", err)
return
}
- defer func() { _ = sshClient.Close() }()
+
+ if !isPty && !hasCommand {
+ p.handleNonInteractiveSession(session, sshClient)
+ return
+ }
serverSession, err := sshClient.NewSession()
if err != nil {
@@ -140,7 +165,6 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw
serverSession.Stdout = session
serverSession.Stderr = session.Stderr()
- ptyReq, winCh, isPty := session.Pty()
if isPty {
if err := serverSession.RequestPty(ptyReq.Term, ptyReq.Window.Width, ptyReq.Window.Height, nil); err != nil {
log.Debugf("PTY request to backend: %v", err)
@@ -155,7 +179,7 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw
}()
}
- if len(session.Command()) > 0 {
+ if hasCommand {
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
log.Debugf("run command: %v", err)
p.handleProxyExitCode(session, err)
@@ -176,12 +200,29 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw
func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
var exitErr *cryptossh.ExitError
if errors.As(err, &exitErr) {
- if exitErr := session.Exit(exitErr.ExitStatus()); exitErr != nil {
- log.Debugf("set exit status: %v", exitErr)
+ if err := session.Exit(exitErr.ExitStatus()); err != nil {
+ log.Debugf("set exit status: %v", err)
}
}
}
+func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *cryptossh.Client) {
+ // Create a backend session to mirror the client's session request.
+ // This keeps the connection alive on the server side while port forwarding channels operate.
+ serverSession, err := sshClient.NewSession()
+ if err != nil {
+ _, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
+ return
+ }
+ defer func() { _ = serverSession.Close() }()
+
+ <-session.Context().Done()
+
+ if err := session.Exit(0); err != nil {
+ log.Debugf("session exit: %v", err)
+ }
+}
+
func generateHostKey() (ssh.Signer, error) {
keyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
if err != nil {
@@ -250,8 +291,52 @@ func (c *stdioConn) SetWriteDeadline(_ time.Time) error {
return nil
}
-func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, _ ssh.Context) {
- _ = newChan.Reject(cryptossh.Prohibited, "port forwarding not supported in proxy")
+// directTCPIPHandler handles local port forwarding (direct-tcpip channel).
+func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, sshCtx ssh.Context) {
+ var payload struct {
+ DestAddr string
+ DestPort uint32
+ OriginAddr string
+ OriginPort uint32
+ }
+ if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil {
+ _, _ = fmt.Fprintf(p.stderr, "parse direct-tcpip payload: %v\n", err)
+ _ = newChan.Reject(cryptossh.ConnectionFailed, "invalid payload")
+ return
+ }
+
+ dest := fmt.Sprintf("%s:%d", payload.DestAddr, payload.DestPort)
+ log.Debugf("local port forwarding: %s", dest)
+
+ backendClient, err := p.getOrCreateBackendClient(sshCtx, sshCtx.User())
+ if err != nil {
+ _, _ = fmt.Fprintf(p.stderr, "backend connection for port forwarding: %v\n", err)
+ _ = newChan.Reject(cryptossh.ConnectionFailed, "backend connection failed")
+ return
+ }
+
+ backendChan, backendReqs, err := backendClient.OpenChannel("direct-tcpip", newChan.ExtraData())
+ if err != nil {
+ _, _ = fmt.Fprintf(p.stderr, "open backend channel for %s: %v\n", dest, err)
+ var openErr *cryptossh.OpenChannelError
+ if errors.As(err, &openErr) {
+ _ = newChan.Reject(openErr.Reason, openErr.Message)
+ } else {
+ _ = newChan.Reject(cryptossh.ConnectionFailed, err.Error())
+ }
+ return
+ }
+ go cryptossh.DiscardRequests(backendReqs)
+
+ clientChan, clientReqs, err := newChan.Accept()
+ if err != nil {
+ log.Debugf("local port forwarding: accept channel: %v", err)
+ _ = backendChan.Close()
+ return
+ }
+ go cryptossh.DiscardRequests(clientReqs)
+
+ nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
}
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
@@ -354,12 +439,143 @@ func (p *SSHProxy) runSFTPBridge(ctx context.Context, s ssh.Session, stdin io.Wr
}
}
-func (p *SSHProxy) tcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
- return false, []byte("port forwarding not supported in proxy")
+// tcpipForwardHandler handles remote port forwarding (tcpip-forward request).
+func (p *SSHProxy) tcpipForwardHandler(sshCtx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
+ var reqPayload struct {
+ Host string
+ Port uint32
+ }
+ if err := cryptossh.Unmarshal(req.Payload, &reqPayload); err != nil {
+ _, _ = fmt.Fprintf(p.stderr, "parse tcpip-forward payload: %v\n", err)
+ return false, nil
+ }
+
+ log.Debugf("tcpip-forward request for %s:%d", reqPayload.Host, reqPayload.Port)
+
+ backendClient, err := p.getOrCreateBackendClient(sshCtx, sshCtx.User())
+ if err != nil {
+ _, _ = fmt.Fprintf(p.stderr, "backend connection for remote port forwarding: %v\n", err)
+ return false, nil
+ }
+
+ ok, payload, err := backendClient.SendRequest(req.Type, req.WantReply, req.Payload)
+ if err != nil {
+ _, _ = fmt.Fprintf(p.stderr, "forward tcpip-forward request for %s:%d: %v\n", reqPayload.Host, reqPayload.Port, err)
+ return false, nil
+ }
+
+ if ok {
+ actualPort := reqPayload.Port
+ if reqPayload.Port == 0 && len(payload) >= 4 {
+ actualPort = binary.BigEndian.Uint32(payload)
+ }
+ log.Debugf("remote port forwarding established for %s:%d", reqPayload.Host, actualPort)
+ p.forwardedChannelsOnce.Do(func() {
+ go p.handleForwardedChannels(sshCtx, backendClient)
+ })
+ }
+
+ return ok, payload
}
-func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
- return true, nil
+// cancelTcpipForwardHandler handles cancel-tcpip-forward request.
+func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
+ var reqPayload struct {
+ Host string
+ Port uint32
+ }
+ if err := cryptossh.Unmarshal(req.Payload, &reqPayload); err != nil {
+ _, _ = fmt.Fprintf(p.stderr, "parse cancel-tcpip-forward payload: %v\n", err)
+ return false, nil
+ }
+
+ log.Debugf("cancel-tcpip-forward request for %s:%d", reqPayload.Host, reqPayload.Port)
+
+ backendClient := p.getBackendClient()
+ if backendClient == nil {
+ return false, nil
+ }
+
+ ok, payload, err := backendClient.SendRequest(req.Type, req.WantReply, req.Payload)
+ if err != nil {
+ _, _ = fmt.Fprintf(p.stderr, "cancel-tcpip-forward for %s:%d: %v\n", reqPayload.Host, reqPayload.Port, err)
+ return false, nil
+ }
+
+ return ok, payload
+}
+
+// getOrCreateBackendClient returns the existing backend client or creates a new one.
+func (p *SSHProxy) getOrCreateBackendClient(ctx context.Context, user string) (*cryptossh.Client, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if p.backendClient != nil {
+ return p.backendClient, nil
+ }
+
+ targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
+ log.Debugf("connecting to backend %s", targetAddr)
+
+ client, err := p.dialBackend(ctx, targetAddr, user, p.jwtToken)
+ if err != nil {
+ return nil, err
+ }
+
+ log.Debugf("backend connection established to %s", targetAddr)
+ p.backendClient = client
+ return client, nil
+}
+
+// getBackendClient returns the existing backend client or nil.
+func (p *SSHProxy) getBackendClient() *cryptossh.Client {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ return p.backendClient
+}
+
+// handleForwardedChannels handles forwarded-tcpip channels from the backend for remote port forwarding.
+// When the backend receives incoming connections on the forwarded port, it sends them as
+// "forwarded-tcpip" channels which we need to proxy to the client.
+func (p *SSHProxy) handleForwardedChannels(sshCtx ssh.Context, backendClient *cryptossh.Client) {
+ sshConn, ok := sshCtx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn)
+ if !ok || sshConn == nil {
+ log.Debugf("no SSH connection in context for forwarded channels")
+ return
+ }
+
+ channelChan := backendClient.HandleChannelOpen("forwarded-tcpip")
+ for {
+ select {
+ case <-sshCtx.Done():
+ return
+ case newChannel, ok := <-channelChan:
+ if !ok {
+ return
+ }
+ go p.handleForwardedChannel(sshCtx, sshConn, newChannel)
+ }
+ }
+}
+
+// handleForwardedChannel handles a single forwarded-tcpip channel from the backend.
+func (p *SSHProxy) handleForwardedChannel(sshCtx ssh.Context, sshConn *cryptossh.ServerConn, newChannel cryptossh.NewChannel) {
+ backendChan, backendReqs, err := newChannel.Accept()
+ if err != nil {
+ log.Debugf("remote port forwarding: accept from backend: %v", err)
+ return
+ }
+ go cryptossh.DiscardRequests(backendReqs)
+
+ clientChan, clientReqs, err := sshConn.OpenChannel("forwarded-tcpip", newChannel.ExtraData())
+ if err != nil {
+ log.Debugf("remote port forwarding: open to client: %v", err)
+ _ = backendChan.Close()
+ return
+ }
+ go cryptossh.DiscardRequests(clientReqs)
+
+ nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
}
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {
diff --git a/client/ssh/proxy/proxy_test.go b/client/ssh/proxy/proxy_test.go
index 582f9c07b..81d588801 100644
--- a/client/ssh/proxy/proxy_test.go
+++ b/client/ssh/proxy/proxy_test.go
@@ -27,9 +27,11 @@ import (
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
+ sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
+ sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
func TestMain(m *testing.M) {
@@ -137,6 +139,21 @@ func TestSSHProxy_Connect(t *testing.T) {
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)
+ // Configure SSH authorization for the test user
+ testUsername := testutil.GetTestUsername(t)
+ testJWTUser := "test-username"
+ testUserHash, err := sshuserhash.HashUserID(testJWTUser)
+ require.NoError(t, err)
+
+ authConfig := &sshauth.Config{
+ UserIDClaim: sshauth.DefaultUserIDClaim,
+ AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
+ MachineUsers: map[string][]uint32{
+ testUsername: {0}, // Index 0 in AuthorizedUsers
+ },
+ }
+ sshServer.UpdateSSHAuth(authConfig)
+
sshServerAddr := server.StartTestServer(t, sshServer)
defer func() { _ = sshServer.Stop() }()
@@ -150,10 +167,10 @@ func TestSSHProxy_Connect(t *testing.T) {
mockDaemon.setHostKey(host, hostPubKey)
- validToken := generateValidJWT(t, privateKey, issuer, audience)
+ validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
mockDaemon.setJWTToken(validToken)
- proxyInstance, err := New(mockDaemon.addr, host, port, nil, nil)
+ proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
require.NoError(t, err)
clientConn, proxyConn := net.Pipe()
@@ -347,12 +364,12 @@ func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
return privateKey, jwksJSON
}
-func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
+func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string {
t.Helper()
claims := jwt.MapClaims{
"iss": issuer,
"aud": audience,
- "sub": "test-user",
+ "sub": user,
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
}
diff --git a/client/ssh/server/jwt_test.go b/client/ssh/server/jwt_test.go
index 1f3bac76d..d36d7cbbf 100644
--- a/client/ssh/server/jwt_test.go
+++ b/client/ssh/server/jwt_test.go
@@ -23,10 +23,12 @@ import (
"github.com/stretchr/testify/require"
nbssh "github.com/netbirdio/netbird/client/ssh"
+ sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/client"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
+ sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
func TestJWTEnforcement(t *testing.T) {
@@ -577,6 +579,22 @@ func TestJWTAuthentication(t *testing.T) {
tc.setupServer(server)
}
+ // Always set up authorization for test-user so that failing cases fail at the JWT validation stage, not at the authorization stage
+ testUserHash, err := sshuserhash.HashUserID("test-user")
+ require.NoError(t, err)
+
+ // Get current OS username for machine user mapping
+ currentUser := testutil.GetTestUsername(t)
+
+ authConfig := &sshauth.Config{
+ UserIDClaim: sshauth.DefaultUserIDClaim,
+ AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
+ MachineUsers: map[string][]uint32{
+ currentUser: {0}, // Allow test-user (index 0) to access current OS user
+ },
+ }
+ server.UpdateSSHAuth(authConfig)
+
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
diff --git a/client/ssh/server/port_forwarding.go b/client/ssh/server/port_forwarding.go
index 6138f9296..c60cf4f58 100644
--- a/client/ssh/server/port_forwarding.go
+++ b/client/ssh/server/port_forwarding.go
@@ -1,25 +1,32 @@
+// Package server implements port forwarding for the SSH server.
+//
+// Security note: Port forwarding runs in the main server process without privilege separation.
+// The attack surface is primarily io.Copy through well-tested standard library code, making it
+// lower risk than shell execution which uses privilege-separated child processes. We enforce
+// user-level port restrictions: non-privileged users cannot bind to ports < 1024.
package server
import (
"encoding/binary"
"fmt"
- "io"
"net"
+ "runtime"
"strconv"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
+
+ nbssh "github.com/netbirdio/netbird/client/ssh"
)
-// SessionKey uniquely identifies an SSH session
-type SessionKey string
+const privilegedPortThreshold = 1024
-// ConnectionKey uniquely identifies a port forwarding connection within a session
-type ConnectionKey string
+// sessionKey uniquely identifies an SSH session
+type sessionKey string
-// ForwardKey uniquely identifies a port forwarding listener
-type ForwardKey string
+// forwardKey uniquely identifies a port forwarding listener
+type forwardKey string
// tcpipForwardMsg represents the structure for tcpip-forward SSH requests
type tcpipForwardMsg struct {
@@ -47,34 +54,32 @@ func (s *Server) configurePortForwarding(server *ssh.Server) {
allowRemote := s.allowRemotePortForwarding
server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool {
+ logger := s.getRequestLogger(ctx)
if !allowLocal {
- log.Warnf("local port forwarding denied for %s from %s: disabled by configuration",
- net.JoinHostPort(dstHost, fmt.Sprintf("%d", dstPort)), ctx.RemoteAddr())
+ logger.Warnf("local port forwarding denied for %s:%d: disabled", dstHost, dstPort)
return false
}
if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil {
- log.Warnf("local port forwarding denied for %s:%d from %s: %v", dstHost, dstPort, ctx.RemoteAddr(), err)
+ logger.Warnf("local port forwarding denied for %s:%d: %v", dstHost, dstPort, err)
return false
}
- log.Debugf("local port forwarding allowed: %s:%d", dstHost, dstPort)
return true
}
server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
+ logger := s.getRequestLogger(ctx)
if !allowRemote {
- log.Warnf("remote port forwarding denied for %s from %s: disabled by configuration",
- net.JoinHostPort(bindHost, fmt.Sprintf("%d", bindPort)), ctx.RemoteAddr())
+ logger.Warnf("remote port forwarding denied for %s:%d: disabled", bindHost, bindPort)
return false
}
if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil {
- log.Warnf("remote port forwarding denied for %s:%d from %s: %v", bindHost, bindPort, ctx.RemoteAddr(), err)
+ logger.Warnf("remote port forwarding denied for %s:%d: %v", bindHost, bindPort, err)
return false
}
- log.Debugf("remote port forwarding allowed: %s:%d", bindHost, bindPort)
return true
}
@@ -82,23 +87,20 @@ func (s *Server) configurePortForwarding(server *ssh.Server) {
}
// checkPortForwardingPrivileges validates privilege requirements for port forwarding operations.
-// Returns nil if allowed, error if denied.
+// For remote port forwarding (binding), it enforces that non-privileged users cannot bind to
+// ports below 1024, mirroring the restriction they would face if binding directly.
+//
+// Note: FeatureSupportsUserSwitch is true because we accept requests from any authenticated user,
+// though we don't actually switch users - port forwarding runs in the server process. The resolved
+// user is used for privileged port access checks.
func (s *Server) checkPortForwardingPrivileges(ctx ssh.Context, forwardType string, port uint32) error {
if ctx == nil {
return fmt.Errorf("%s port forwarding denied: no context", forwardType)
}
- username := ctx.User()
- remoteAddr := "unknown"
- if ctx.RemoteAddr() != nil {
- remoteAddr = ctx.RemoteAddr().String()
- }
-
- logger := log.WithFields(log.Fields{"user": username, "remote": remoteAddr, "port": port})
-
result := s.CheckPrivileges(PrivilegeCheckRequest{
- RequestedUsername: username,
- FeatureSupportsUserSwitch: false,
+ RequestedUsername: ctx.User(),
+ FeatureSupportsUserSwitch: true,
FeatureName: forwardType + " port forwarding",
})
@@ -106,12 +108,42 @@ func (s *Server) checkPortForwardingPrivileges(ctx ssh.Context, forwardType stri
return result.Error
}
- logger.Debugf("%s port forwarding allowed: user %s validated (port %d)",
- forwardType, result.User.Username, port)
+ if err := s.checkPrivilegedPortAccess(forwardType, port, result); err != nil {
+ return err
+ }
return nil
}
+// checkPrivilegedPortAccess enforces that non-privileged users cannot bind to privileged ports.
+// This applies to remote port forwarding where the server binds a port on behalf of the user.
+// On Windows, there is no privileged port restriction, so this check is skipped.
+func (s *Server) checkPrivilegedPortAccess(forwardType string, port uint32, result PrivilegeCheckResult) error {
+ if runtime.GOOS == "windows" {
+ return nil
+ }
+
+ isBindOperation := forwardType == "remote" || forwardType == "tcpip-forward"
+ if !isBindOperation {
+ return nil
+ }
+
+	// Port 0 means "pick any available port"; the kernel typically allocates from the ephemeral range (>= 1024)
+ if port == 0 || port >= privilegedPortThreshold {
+ return nil
+ }
+
+ if result.User != nil && isPrivilegedUsername(result.User.Username) {
+ return nil
+ }
+
+ username := "unknown"
+ if result.User != nil {
+ username = result.User.Username
+ }
+ return fmt.Errorf("user %s cannot bind to privileged port %d (requires root)", username, port)
+}
+
// tcpipForwardHandler handles tcpip-forward requests for remote port forwarding.
func (s *Server) tcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
logger := s.getRequestLogger(ctx)
@@ -132,8 +164,6 @@ func (s *Server) tcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *crypto
return false, nil
}
- logger.Debugf("tcpip-forward request: %s:%d", payload.Host, payload.Port)
-
sshConn, err := s.getSSHConnection(ctx)
if err != nil {
logger.Warnf("tcpip-forward request denied: %v", err)
@@ -153,8 +183,10 @@ func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *
return false, nil
}
- key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
+ key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
if s.removeRemoteForwardListener(key) {
+ forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, payload.Port)
+ s.removeConnectionPortForward(ctx.RemoteAddr(), forwardAddr)
logger.Infof("remote port forwarding cancelled: %s:%d", payload.Host, payload.Port)
return true, nil
}
@@ -165,14 +197,11 @@ func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *
// handleRemoteForwardListener handles incoming connections for remote port forwarding.
func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, host string, port uint32) {
- log.Debugf("starting remote forward listener handler for %s:%d", host, port)
+ logger := s.getRequestLogger(ctx)
defer func() {
- log.Debugf("cleaning up remote forward listener for %s:%d", host, port)
if err := ln.Close(); err != nil {
- log.Debugf("remote forward listener close error: %v", err)
- } else {
- log.Debugf("remote forward listener closed successfully for %s:%d", host, port)
+ logger.Debugf("remote forward listener close error for %s:%d: %v", host, port, err)
}
}()
@@ -196,28 +225,43 @@ func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, h
select {
case result := <-acceptChan:
if result.err != nil {
- log.Debugf("remote forward accept error: %v", result.err)
+ logger.Debugf("remote forward accept error: %v", result.err)
return
}
go s.handleRemoteForwardConnection(ctx, result.conn, host, port)
case <-ctx.Done():
- log.Debugf("remote forward listener shutting down due to context cancellation for %s:%d", host, port)
+ logger.Debugf("remote forward listener shutting down for %s:%d", host, port)
return
}
}
}
-// getRequestLogger creates a logger with user and remote address context
+// getRequestLogger creates a logger with session/conn and jwt_user context
func (s *Server) getRequestLogger(ctx ssh.Context) *log.Entry {
- remoteAddr := "unknown"
- username := "unknown"
- if ctx != nil {
- if ctx.RemoteAddr() != nil {
- remoteAddr = ctx.RemoteAddr().String()
+ sessionKey := s.findSessionKeyByContext(ctx)
+
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if state, exists := s.sessions[sessionKey]; exists {
+ logger := log.WithField("session", sessionKey)
+ if state.jwtUsername != "" {
+ logger = logger.WithField("jwt_user", state.jwtUsername)
}
- username = ctx.User()
+ return logger
}
- return log.WithFields(log.Fields{"user": username, "remote": remoteAddr})
+
+ if ctx.RemoteAddr() != nil {
+ if connState, exists := s.connections[connKey(ctx.RemoteAddr().String())]; exists {
+ return s.connLogger(connState)
+ }
+ }
+
+ remoteAddr := "unknown"
+ if ctx.RemoteAddr() != nil {
+ remoteAddr = ctx.RemoteAddr().String()
+ }
+ return log.WithField("session", fmt.Sprintf("%s@%s", ctx.User(), remoteAddr))
}
// isRemotePortForwardingAllowed checks if remote port forwarding is enabled
@@ -227,6 +271,13 @@ func (s *Server) isRemotePortForwardingAllowed() bool {
return s.allowRemotePortForwarding
}
+// isPortForwardingEnabled checks if any port forwarding (local or remote) is enabled
+func (s *Server) isPortForwardingEnabled() bool {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ return s.allowLocalPortForwarding || s.allowRemotePortForwarding
+}
+
// parseTcpipForwardRequest parses the SSH request payload
func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) {
var payload tcpipForwardMsg
@@ -267,10 +318,11 @@ func (s *Server) setupDirectForward(ctx ssh.Context, logger *log.Entry, sshConn
logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host)
}
- key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
+ key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
s.storeRemoteForwardListener(key, ln)
- s.markConnectionActivePortForward(sshConn, ctx.User(), ctx.RemoteAddr().String())
+ forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, actualPort)
+ s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort)
response := make([]byte, 4)
@@ -288,44 +340,34 @@ type acceptResult struct {
// handleRemoteForwardConnection handles a single remote port forwarding connection
func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, host string, port uint32) {
- sessionKey := s.findSessionKeyByContext(ctx)
- connID := fmt.Sprintf("pf-%s->%s:%d", conn.RemoteAddr(), host, port)
- logger := log.WithFields(log.Fields{
- "session": sessionKey,
- "conn": connID,
- })
+ logger := s.getRequestLogger(ctx)
- defer func() {
- if err := conn.Close(); err != nil {
- logger.Debugf("connection close error: %v", err)
- }
- }()
-
- sshConn := ctx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn)
- if sshConn == nil {
+ sshConn, ok := ctx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn)
+ if !ok || sshConn == nil {
logger.Debugf("remote forward: no SSH connection in context")
+ _ = conn.Close()
return
}
remoteAddr, ok := conn.RemoteAddr().(*net.TCPAddr)
if !ok {
logger.Warnf("remote forward: non-TCP connection type: %T", conn.RemoteAddr())
+ _ = conn.Close()
return
}
- channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr, logger)
+ channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr)
if err != nil {
- logger.Debugf("open forward channel: %v", err)
+ logger.Debugf("open forward channel for %s:%d: %v", host, port, err)
+ _ = conn.Close()
return
}
- s.proxyForwardConnection(ctx, logger, conn, channel)
+ nbssh.BidirectionalCopyWithContext(logger, ctx, conn, channel)
}
// openForwardChannel creates an SSH forwarded-tcpip channel
-func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string, port uint32, remoteAddr *net.TCPAddr, logger *log.Entry) (cryptossh.Channel, error) {
- logger.Tracef("opening forwarded-tcpip channel for %s:%d", host, port)
-
+func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string, port uint32, remoteAddr *net.TCPAddr) (cryptossh.Channel, error) {
payload := struct {
ConnectedAddress string
ConnectedPort uint32
@@ -346,41 +388,3 @@ func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string,
go cryptossh.DiscardRequests(reqs)
return channel, nil
}
-
-// proxyForwardConnection handles bidirectional data transfer between connection and SSH channel
-func (s *Server) proxyForwardConnection(ctx ssh.Context, logger *log.Entry, conn net.Conn, channel cryptossh.Channel) {
- done := make(chan struct{}, 2)
-
- go func() {
- if _, err := io.Copy(channel, conn); err != nil {
- logger.Debugf("copy error (conn->channel): %v", err)
- }
- done <- struct{}{}
- }()
-
- go func() {
- if _, err := io.Copy(conn, channel); err != nil {
- logger.Debugf("copy error (channel->conn): %v", err)
- }
- done <- struct{}{}
- }()
-
- select {
- case <-ctx.Done():
- logger.Debugf("session ended, closing connections")
- case <-done:
- // First copy finished, wait for second copy or context cancellation
- select {
- case <-ctx.Done():
- logger.Debugf("session ended, closing connections")
- case <-done:
- }
- }
-
- if err := channel.Close(); err != nil {
- logger.Debugf("channel close error: %v", err)
- }
- if err := conn.Close(); err != nil {
- logger.Debugf("connection close error: %v", err)
- }
-}
diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go
index 37763ee0e..3a8568979 100644
--- a/client/ssh/server/server.go
+++ b/client/ssh/server/server.go
@@ -9,6 +9,7 @@ import (
"io"
"net"
"net/netip"
+ "slices"
"strings"
"sync"
"time"
@@ -21,6 +22,7 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
+ sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/auth/jwt"
@@ -39,6 +41,11 @@ const (
msgPrivilegedUserDisabled = "privileged user login is disabled"
+	cmdInteractiveShell = "" // NOTE(review): all four cmd* values are empty/identical here — values appear stripped; comparisons like cmd == cmdInteractiveShell vs cmdNonInteractive are meaningless until distinct values are restored. Verify against upstream.
+	cmdPortForwarding   = ""
+	cmdSFTP             = ""
+	cmdNonInteractive   = ""
+
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server
DefaultJWTMaxTokenAge = 5 * 60
)
@@ -89,10 +96,10 @@ func logSessionExitError(logger *log.Entry, err error) {
}
}
-// safeLogCommand returns a safe representation of the command for logging
+// safeLogCommand returns a safe representation of the command for logging.
func safeLogCommand(cmd []string) string {
if len(cmd) == 0 {
- return ""
+ return cmdInteractiveShell
}
if len(cmd) == 1 {
return cmd[0]
@@ -100,26 +107,50 @@ func safeLogCommand(cmd []string) string {
return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1)
}
-type sshConnectionState struct {
- hasActivePortForward bool
- username string
- remoteAddr string
+// connState tracks the state of an SSH connection for port forwarding and status display.
+type connState struct {
+ username string
+ remoteAddr net.Addr
+ portForwards []string
+ jwtUsername string
}
+// authKey uniquely identifies an authentication attempt by username and remote address.
+// Used to temporarily store JWT username between passwordHandler and sessionHandler.
type authKey string
+// connKey uniquely identifies an SSH connection by its remote address.
+// Used to track authenticated connections for status display and port forwarding.
+type connKey string
+
func newAuthKey(username string, remoteAddr net.Addr) authKey {
return authKey(fmt.Sprintf("%s@%s", username, remoteAddr.String()))
}
+// sessionState tracks an active SSH session (shell, command, or subsystem like SFTP).
+type sessionState struct {
+ session ssh.Session
+ sessionType string
+ jwtUsername string
+}
+
type Server struct {
- sshServer *ssh.Server
- mu sync.RWMutex
- hostKeyPEM []byte
- sessions map[SessionKey]ssh.Session
- sessionCancels map[ConnectionKey]context.CancelFunc
- sessionJWTUsers map[SessionKey]string
- pendingAuthJWT map[authKey]string
+ sshServer *ssh.Server
+ listener net.Listener
+ mu sync.RWMutex
+ hostKeyPEM []byte
+
+ // sessions tracks active SSH sessions (shell, command, SFTP).
+ // These are created when a client opens a session channel and requests shell/exec/subsystem.
+ sessions map[sessionKey]*sessionState
+
+ // pendingAuthJWT temporarily stores JWT username during the auth→session handoff.
+ // Populated in passwordHandler, consumed in sessionHandler/sftpSubsystemHandler.
+ pendingAuthJWT map[authKey]string
+
+ // connections tracks all SSH connections by their remote address.
+ // Populated at authentication time, stores JWT username and port forwards for status display.
+ connections map[connKey]*connState
allowLocalPortForwarding bool
allowRemotePortForwarding bool
@@ -131,13 +162,14 @@ type Server struct {
wgAddress wgaddr.Address
- remoteForwardListeners map[ForwardKey]net.Listener
- sshConnections map[*cryptossh.ServerConn]*sshConnectionState
+ remoteForwardListeners map[forwardKey]net.Listener
jwtValidator *jwt.Validator
jwtExtractor *jwt.ClaimsExtractor
jwtConfig *JWTConfig
+ authorizer *sshauth.Authorizer
+
suSupportsPty bool
loginIsUtilLinux bool
}
@@ -164,6 +196,7 @@ type SessionInfo struct {
RemoteAddress string
Command string
JWTUsername string
+ PortForwards []string
}
// New creates an SSH server instance with the provided host key and optional JWT configuration
@@ -172,13 +205,13 @@ func New(config *Config) *Server {
s := &Server{
mu: sync.RWMutex{},
hostKeyPEM: config.HostKeyPEM,
- sessions: make(map[SessionKey]ssh.Session),
- sessionJWTUsers: make(map[SessionKey]string),
+ sessions: make(map[sessionKey]*sessionState),
pendingAuthJWT: make(map[authKey]string),
- remoteForwardListeners: make(map[ForwardKey]net.Listener),
- sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
+ remoteForwardListeners: make(map[forwardKey]net.Listener),
+ connections: make(map[connKey]*connState),
jwtEnabled: config.JWT != nil,
jwtConfig: config.JWT,
+ authorizer: sshauth.NewAuthorizer(), // Initialize with empty config
}
return s
@@ -207,6 +240,7 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
return fmt.Errorf("create SSH server: %w", err)
}
+ s.listener = ln
s.sshServer = sshServer
log.Infof("SSH server started on %s", addrDesc)
@@ -259,16 +293,11 @@ func (s *Server) Stop() error {
}
s.sshServer = nil
+ s.listener = nil
maps.Clear(s.sessions)
- maps.Clear(s.sessionJWTUsers)
maps.Clear(s.pendingAuthJWT)
- maps.Clear(s.sshConnections)
-
- for _, cancelFunc := range s.sessionCancels {
- cancelFunc()
- }
- maps.Clear(s.sessionCancels)
+ maps.Clear(s.connections)
for _, listener := range s.remoteForwardListeners {
if err := listener.Close(); err != nil {
@@ -280,32 +309,82 @@ func (s *Server) Stop() error {
return nil
}
-// GetStatus returns the current status of the SSH server and active sessions
+// Addr returns the address the SSH server is listening on, or nil if the server is not running
+func (s *Server) Addr() net.Addr {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if s.listener == nil {
+ return nil
+ }
+
+ return s.listener.Addr()
+}
+
+// GetStatus returns the current status of the SSH server and active sessions.
func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) {
s.mu.RLock()
defer s.mu.RUnlock()
enabled = s.sshServer != nil
+ reportedAddrs := make(map[string]bool)
- for sessionKey, session := range s.sessions {
- cmd := ""
- if len(session.Command()) > 0 {
- cmd = safeLogCommand(session.Command())
+ for _, state := range s.sessions {
+ info := s.buildSessionInfo(state)
+ reportedAddrs[info.RemoteAddress] = true
+ sessions = append(sessions, info)
+ }
+
+ // Add authenticated connections without sessions (e.g., -N/-T or port-forwarding only)
+ for key, connState := range s.connections {
+ remoteAddr := string(key)
+ if reportedAddrs[remoteAddr] {
+ continue
+ }
+ cmd := cmdNonInteractive
+ if len(connState.portForwards) > 0 {
+ cmd = cmdPortForwarding
}
-
- jwtUsername := s.sessionJWTUsers[sessionKey]
-
sessions = append(sessions, SessionInfo{
- Username: session.User(),
- RemoteAddress: session.RemoteAddr().String(),
+ Username: connState.username,
+ RemoteAddress: remoteAddr,
Command: cmd,
- JWTUsername: jwtUsername,
+ JWTUsername: connState.jwtUsername,
+ PortForwards: connState.portForwards,
})
}
return enabled, sessions
}
+func (s *Server) buildSessionInfo(state *sessionState) SessionInfo {
+ session := state.session
+ cmd := state.sessionType
+ if cmd == "" {
+ cmd = safeLogCommand(session.Command())
+ }
+
+ remoteAddr := session.RemoteAddr().String()
+ info := SessionInfo{
+ Username: session.User(),
+ RemoteAddress: remoteAddr,
+ Command: cmd,
+ JWTUsername: state.jwtUsername,
+ }
+
+ connState, exists := s.connections[connKey(remoteAddr)]
+ if !exists {
+ return info
+ }
+
+ info.PortForwards = connState.portForwards
+ if len(connState.portForwards) > 0 && (cmd == cmdInteractiveShell || cmd == cmdNonInteractive) {
+ info.Command = cmdPortForwarding
+ }
+
+ return info
+}
+
// SetNetstackNet sets the netstack network for userspace networking
func (s *Server) SetNetstackNet(net *netstack.Net) {
s.mu.Lock()
@@ -320,6 +399,19 @@ func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
s.wgAddress = addr
}
+// UpdateSSHAuth updates the SSH fine-grained access control configuration
+// This should be called when network map updates include new SSH auth configuration
+func (s *Server) UpdateSSHAuth(config *sshauth.Config) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // Reset JWT validator/extractor to pick up new userIDClaim
+ s.jwtValidator = nil
+ s.jwtExtractor = nil
+
+ s.authorizer.Update(config)
+}
+
// ensureJWTValidator initializes the JWT validator and extractor if not already initialized
func (s *Server) ensureJWTValidator() error {
s.mu.RLock()
@@ -328,6 +420,7 @@ func (s *Server) ensureJWTValidator() error {
return nil
}
config := s.jwtConfig
+ authorizer := s.authorizer
s.mu.RUnlock()
if config == nil {
@@ -343,9 +436,16 @@ func (s *Server) ensureJWTValidator() error {
true,
)
- extractor := jwt.NewClaimsExtractor(
+ // Use custom userIDClaim from authorizer if available
+ extractorOptions := []jwt.ClaimsExtractorOption{
jwt.WithAudience(config.Audience),
- )
+ }
+ if authorizer.GetUserIDClaim() != "" {
+ extractorOptions = append(extractorOptions, jwt.WithUserIDClaim(authorizer.GetUserIDClaim()))
+ log.Debugf("Using custom user ID claim: %s", authorizer.GetUserIDClaim())
+ }
+
+ extractor := jwt.NewClaimsExtractor(extractorOptions...)
s.mu.Lock()
defer s.mu.Unlock()
@@ -493,59 +593,131 @@ func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]int
}
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
+ osUsername := ctx.User()
+ remoteAddr := ctx.RemoteAddr()
+ logger := s.getRequestLogger(ctx)
+
if err := s.ensureJWTValidator(); err != nil {
- log.Errorf("JWT validator initialization failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
+ logger.Errorf("JWT validator initialization failed: %v", err)
return false
}
token, err := s.validateJWTToken(password)
if err != nil {
- log.Warnf("JWT authentication failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
+ logger.Warnf("JWT authentication failed: %v", err)
return false
}
userAuth, err := s.extractAndValidateUser(token)
if err != nil {
- log.Warnf("User validation failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
+ logger.Warnf("user validation failed: %v", err)
return false
}
- key := newAuthKey(ctx.User(), ctx.RemoteAddr())
+ logger = logger.WithField("jwt_user", userAuth.UserId)
+
+ s.mu.RLock()
+ authorizer := s.authorizer
+ s.mu.RUnlock()
+
+ msg, err := authorizer.Authorize(userAuth.UserId, osUsername)
+ if err != nil {
+ logger.Warnf("SSH auth denied: %v", err)
+ return false
+ }
+
+ logger.Infof("SSH auth %s", msg)
+
+ key := newAuthKey(osUsername, remoteAddr)
+ remoteAddrStr := ctx.RemoteAddr().String()
s.mu.Lock()
s.pendingAuthJWT[key] = userAuth.UserId
+ s.connections[connKey(remoteAddrStr)] = &connState{
+ username: ctx.User(),
+ remoteAddr: ctx.RemoteAddr(),
+ jwtUsername: userAuth.UserId,
+ }
s.mu.Unlock()
- log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr())
return true
}
-func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn, username, remoteAddr string) {
+func (s *Server) addConnectionPortForward(username string, remoteAddr net.Addr, forwardAddr string) {
s.mu.Lock()
defer s.mu.Unlock()
- if state, exists := s.sshConnections[sshConn]; exists {
- state.hasActivePortForward = true
- } else {
- s.sshConnections[sshConn] = &sshConnectionState{
- hasActivePortForward: true,
- username: username,
- remoteAddr: remoteAddr,
+ key := connKey(remoteAddr.String())
+ if state, exists := s.connections[key]; exists {
+ if !slices.Contains(state.portForwards, forwardAddr) {
+ state.portForwards = append(state.portForwards, forwardAddr)
}
+ return
+ }
+
+ // Connection not in connections (non-JWT auth path)
+ s.connections[key] = &connState{
+ username: username,
+ remoteAddr: remoteAddr,
+ portForwards: []string{forwardAddr},
+ jwtUsername: s.pendingAuthJWT[newAuthKey(username, remoteAddr)],
}
}
-func (s *Server) connectionCloseHandler(conn net.Conn, err error) {
- // We can't extract the SSH connection from net.Conn directly
- // Connection cleanup will happen during session cleanup or via timeout
- log.Debugf("SSH connection failed for %s: %v", conn.RemoteAddr(), err)
+func (s *Server) removeConnectionPortForward(remoteAddr net.Addr, forwardAddr string) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ state, exists := s.connections[connKey(remoteAddr.String())]
+ if !exists {
+ return
+ }
+
+ state.portForwards = slices.DeleteFunc(state.portForwards, func(addr string) bool {
+ return addr == forwardAddr
+ })
}
-func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
+// trackedConn wraps a net.Conn to detect when it closes
+type trackedConn struct {
+ net.Conn
+ server *Server
+ remoteAddr string
+ onceClose sync.Once
+}
+
+func (c *trackedConn) Close() error {
+ err := c.Conn.Close()
+ c.onceClose.Do(func() {
+ c.server.handleConnectionClose(c.remoteAddr)
+ })
+ return err
+}
+
+func (s *Server) handleConnectionClose(remoteAddr string) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ key := connKey(remoteAddr)
+ state, exists := s.connections[key]
+ if exists && len(state.portForwards) > 0 {
+ s.connLogger(state).Info("port forwarding connection closed")
+ }
+ delete(s.connections, key)
+}
+
+func (s *Server) connLogger(state *connState) *log.Entry {
+ logger := log.WithField("session", fmt.Sprintf("%s@%s", state.username, state.remoteAddr))
+ if state.jwtUsername != "" {
+ logger = logger.WithField("jwt_user", state.jwtUsername)
+ }
+ return logger
+}
+
+func (s *Server) findSessionKeyByContext(ctx ssh.Context) sessionKey {
if ctx == nil {
return "unknown"
}
- // Try to match by SSH connection
sshConn := ctx.Value(ssh.ContextKeyConn)
if sshConn == nil {
return "unknown"
@@ -554,19 +726,14 @@ func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
s.mu.RLock()
defer s.mu.RUnlock()
- // Look through sessions to find one with matching connection
- for sessionKey, session := range s.sessions {
- if session.Context().Value(ssh.ContextKeyConn) == sshConn {
+ for sessionKey, state := range s.sessions {
+ if state.session.Context().Value(ssh.ContextKeyConn) == sshConn {
return sessionKey
}
}
- // If no session found, this might be during early connection setup
- // Return a temporary key that we'll fix up later
if ctx.User() != "" && ctx.RemoteAddr() != nil {
- tempKey := SessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String()))
- log.Debugf("Using temporary session key for early port forward tracking: %s (will be updated when session established)", tempKey)
- return tempKey
+ return sessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String()))
}
return "unknown"
@@ -607,7 +774,11 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
}
log.Infof("SSH connection from NetBird peer %s allowed", tcpAddr)
- return conn
+ return &trackedConn{
+ Conn: conn,
+ server: s,
+ remoteAddr: conn.RemoteAddr().String(),
+ }
}
func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
@@ -635,9 +806,8 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
"tcpip-forward": s.tcpipForwardHandler,
"cancel-tcpip-forward": s.cancelTcpipForwardHandler,
},
- ConnCallback: s.connectionValidator,
- ConnectionFailedCallback: s.connectionCloseHandler,
- Version: serverVersion,
+ ConnCallback: s.connectionValidator,
+ Version: serverVersion,
}
if s.jwtEnabled {
@@ -653,13 +823,13 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
return server, nil
}
-func (s *Server) storeRemoteForwardListener(key ForwardKey, ln net.Listener) {
+func (s *Server) storeRemoteForwardListener(key forwardKey, ln net.Listener) {
s.mu.Lock()
defer s.mu.Unlock()
s.remoteForwardListeners[key] = ln
}
-func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
+func (s *Server) removeRemoteForwardListener(key forwardKey) bool {
s.mu.Lock()
defer s.mu.Unlock()
@@ -677,6 +847,8 @@ func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
}
func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) {
+ logger := s.getRequestLogger(ctx)
+
var payload struct {
Host string
Port uint32
@@ -686,7 +858,7 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil {
if err := newChan.Reject(cryptossh.ConnectionFailed, "parse payload"); err != nil {
- log.Debugf("channel reject error: %v", err)
+ logger.Debugf("channel reject error: %v", err)
}
return
}
@@ -696,19 +868,20 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
s.mu.RUnlock()
if !allowLocal {
- log.Warnf("local port forwarding denied for %s:%d: disabled by configuration", payload.Host, payload.Port)
+ logger.Warnf("local port forwarding denied for %s:%d: disabled", payload.Host, payload.Port)
_ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled")
return
}
- // Check privilege requirements for the destination port
if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil {
- log.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err)
+ logger.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err)
_ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges")
return
}
- log.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)
+ forwardAddr := fmt.Sprintf("-L %s:%d", payload.Host, payload.Port)
+ s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
+ logger.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)
ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
}
diff --git a/client/ssh/server/server_config_test.go b/client/ssh/server/server_config_test.go
index 24e455025..d85d85a51 100644
--- a/client/ssh/server/server_config_test.go
+++ b/client/ssh/server/server_config_test.go
@@ -224,6 +224,96 @@ func TestServer_PortForwardingRestriction(t *testing.T) {
}
}
+func TestServer_PrivilegedPortAccess(t *testing.T) {
+ hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
+ require.NoError(t, err)
+
+ serverConfig := &Config{
+ HostKeyPEM: hostKey,
+ }
+ server := New(serverConfig)
+ server.SetAllowRemotePortForwarding(true)
+
+ tests := []struct {
+ name string
+ forwardType string
+ port uint32
+ username string
+ expectError bool
+ errorMsg string
+ skipOnWindows bool
+ }{
+ {
+ name: "non-root user remote forward privileged port",
+ forwardType: "remote",
+ port: 80,
+ username: "testuser",
+ expectError: true,
+ errorMsg: "cannot bind to privileged port",
+ skipOnWindows: true,
+ },
+ {
+ name: "non-root user tcpip-forward privileged port",
+ forwardType: "tcpip-forward",
+ port: 443,
+ username: "testuser",
+ expectError: true,
+ errorMsg: "cannot bind to privileged port",
+ skipOnWindows: true,
+ },
+ {
+ name: "non-root user remote forward unprivileged port",
+ forwardType: "remote",
+ port: 8080,
+ username: "testuser",
+ expectError: false,
+ },
+ {
+ name: "non-root user remote forward port 0",
+ forwardType: "remote",
+ port: 0,
+ username: "testuser",
+ expectError: false,
+ },
+ {
+ name: "root user remote forward privileged port",
+ forwardType: "remote",
+ port: 22,
+ username: "root",
+ expectError: false,
+ },
+ {
+ name: "local forward privileged port allowed for non-root",
+ forwardType: "local",
+ port: 80,
+ username: "testuser",
+ expectError: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if tt.skipOnWindows && runtime.GOOS == "windows" {
+ t.Skip("Windows does not have privileged port restrictions")
+ }
+
+ result := PrivilegeCheckResult{
+ Allowed: true,
+ User: &user.User{Username: tt.username},
+ }
+
+ err := server.checkPrivilegedPortAccess(tt.forwardType, tt.port, result)
+
+ if tt.expectError {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errorMsg)
+ } else {
+ require.NoError(t, err)
+ }
+ })
+ }
+}
+
func TestServer_PortConflictHandling(t *testing.T) {
// Test that multiple sessions requesting the same local port are handled naturally by the OS
// Get current user for SSH connection
@@ -392,3 +482,95 @@ func TestServer_IsPrivilegedUser(t *testing.T) {
})
}
}
+
+func TestServer_PortForwardingOnlySession(t *testing.T) {
+ // Test that sessions without PTY and command are allowed when port forwarding is enabled
+ currentUser, err := user.Current()
+ require.NoError(t, err, "Should be able to get current user")
+
+ // Generate host key for server
+ hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
+ require.NoError(t, err)
+
+ tests := []struct {
+ name string
+ allowLocalForwarding bool
+ allowRemoteForwarding bool
+ expectAllowed bool
+ description string
+ }{
+ {
+ name: "session_allowed_with_local_forwarding",
+ allowLocalForwarding: true,
+ allowRemoteForwarding: false,
+ expectAllowed: true,
+ description: "Port-forwarding-only session should be allowed when local forwarding is enabled",
+ },
+ {
+ name: "session_allowed_with_remote_forwarding",
+ allowLocalForwarding: false,
+ allowRemoteForwarding: true,
+ expectAllowed: true,
+ description: "Port-forwarding-only session should be allowed when remote forwarding is enabled",
+ },
+ {
+ name: "session_allowed_with_both",
+ allowLocalForwarding: true,
+ allowRemoteForwarding: true,
+ expectAllowed: true,
+ description: "Port-forwarding-only session should be allowed when both forwarding types enabled",
+ },
+ {
+ name: "session_denied_without_forwarding",
+ allowLocalForwarding: false,
+ allowRemoteForwarding: false,
+ expectAllowed: false,
+ description: "Port-forwarding-only session should be denied when all forwarding is disabled",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ serverConfig := &Config{
+ HostKeyPEM: hostKey,
+ JWT: nil,
+ }
+ server := New(serverConfig)
+ server.SetAllowRootLogin(true)
+ server.SetAllowLocalPortForwarding(tt.allowLocalForwarding)
+ server.SetAllowRemotePortForwarding(tt.allowRemoteForwarding)
+
+ serverAddr := StartTestServer(t, server)
+ defer func() {
+ _ = server.Stop()
+ }()
+
+ // Connect to the server without requesting PTY or command
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ client, err := sshclient.Dial(ctx, serverAddr, currentUser.Username, sshclient.DialOptions{
+ InsecureSkipVerify: true,
+ })
+ require.NoError(t, err)
+ defer func() {
+ _ = client.Close()
+ }()
+
+ // Execute a command without PTY - this simulates ssh -T with no command
+ // The server should either allow it (port forwarding enabled) or reject it
+ output, err := client.ExecuteCommand(ctx, "")
+ if tt.expectAllowed {
+ // When allowed, the session stays open until cancelled
+ // ExecuteCommand with empty command should return without error
+ assert.NoError(t, err, "Session should be allowed when port forwarding is enabled")
+ assert.NotContains(t, output, "port forwarding is disabled",
+ "Output should not contain port forwarding disabled message")
+ } else if err != nil {
+ // When denied, we expect an error message about port forwarding being disabled
+ assert.Contains(t, err.Error(), "port forwarding is disabled",
+ "Should get port forwarding disabled message")
+ }
+ })
+ }
+}
diff --git a/client/ssh/server/session_handlers.go b/client/ssh/server/session_handlers.go
index 4e6d72098..3fd578064 100644
--- a/client/ssh/server/session_handlers.go
+++ b/client/ssh/server/session_handlers.go
@@ -6,37 +6,45 @@ import (
"errors"
"fmt"
"io"
- "strings"
"time"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
- cryptossh "golang.org/x/crypto/ssh"
)
+// associateJWTUsername extracts pending JWT username for the session and associates it with the session state.
+// Returns the JWT username (empty if none) for logging purposes.
+func (s *Server) associateJWTUsername(sess ssh.Session, sessionKey sessionKey) string {
+ key := newAuthKey(sess.User(), sess.RemoteAddr())
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ jwtUsername := s.pendingAuthJWT[key]
+ if jwtUsername == "" {
+ return ""
+ }
+
+ if state, exists := s.sessions[sessionKey]; exists {
+ state.jwtUsername = jwtUsername
+ }
+ delete(s.pendingAuthJWT, key)
+ return jwtUsername
+}
+
// sessionHandler handles SSH sessions
func (s *Server) sessionHandler(session ssh.Session) {
- sessionKey := s.registerSession(session)
-
- key := newAuthKey(session.User(), session.RemoteAddr())
- s.mu.Lock()
- jwtUsername := s.pendingAuthJWT[key]
- if jwtUsername != "" {
- s.sessionJWTUsers[sessionKey] = jwtUsername
- delete(s.pendingAuthJWT, key)
- }
- s.mu.Unlock()
+ sessionKey := s.registerSession(session, "")
+ jwtUsername := s.associateJWTUsername(session, sessionKey)
logger := log.WithField("session", sessionKey)
if jwtUsername != "" {
logger = logger.WithField("jwt_user", jwtUsername)
- logger.Infof("SSH session started (JWT user: %s)", jwtUsername)
- } else {
- logger.Infof("SSH session started")
}
+ logger.Info("SSH session started")
sessionStart := time.Now()
- defer s.unregisterSession(sessionKey, session)
+ defer s.unregisterSession(sessionKey)
defer func() {
duration := time.Since(sessionStart).Round(time.Millisecond)
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
@@ -65,27 +73,52 @@ func (s *Server) sessionHandler(session ssh.Session) {
// ssh - non-Pty command execution
s.handleCommand(logger, session, privilegeResult, nil)
default:
- s.rejectInvalidSession(logger, session)
+ // ssh -T (or ssh -N) - no PTY, no command
+ s.handleNonInteractiveSession(logger, session)
}
}
-func (s *Server) rejectInvalidSession(logger *log.Entry, session ssh.Session) {
- if _, err := io.WriteString(session, "no command specified and Pty not requested\n"); err != nil {
- logger.Debugf(errWriteSession, err)
+// handleNonInteractiveSession handles sessions that have no PTY and no command.
+// These are typically used for port forwarding (ssh -L/-R) or tunneling (ssh -N).
+func (s *Server) handleNonInteractiveSession(logger *log.Entry, session ssh.Session) {
+ s.updateSessionType(session, cmdNonInteractive)
+
+ if !s.isPortForwardingEnabled() {
+ if _, err := io.WriteString(session, "port forwarding is disabled on this server\n"); err != nil {
+ logger.Debugf(errWriteSession, err)
+ }
+ if err := session.Exit(1); err != nil {
+ logSessionExitError(logger, err)
+ }
+ logger.Infof("rejected non-interactive session: port forwarding disabled")
+ return
}
- if err := session.Exit(1); err != nil {
+
+ <-session.Context().Done()
+
+ if err := session.Exit(0); err != nil {
logSessionExitError(logger, err)
}
- logger.Infof("rejected non-Pty session without command from %s", session.RemoteAddr())
}
-func (s *Server) registerSession(session ssh.Session) SessionKey {
+func (s *Server) updateSessionType(session ssh.Session, sessionType string) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ for _, state := range s.sessions {
+ if state.session == session {
+ state.sessionType = sessionType
+ return
+ }
+ }
+}
+
+func (s *Server) registerSession(session ssh.Session, sessionType string) sessionKey {
sessionID := session.Context().Value(ssh.ContextKeySessionID)
if sessionID == nil {
sessionID = fmt.Sprintf("%p", session)
}
- // Create a short 4-byte identifier from the full session ID
hasher := sha256.New()
hasher.Write([]byte(fmt.Sprintf("%v", sessionID)))
hash := hasher.Sum(nil)
@@ -93,43 +126,23 @@ func (s *Server) registerSession(session ssh.Session) SessionKey {
remoteAddr := session.RemoteAddr().String()
username := session.User()
- sessionKey := SessionKey(fmt.Sprintf("%s@%s-%s", username, remoteAddr, shortID))
+ sessionKey := sessionKey(fmt.Sprintf("%s@%s-%s", username, remoteAddr, shortID))
s.mu.Lock()
- s.sessions[sessionKey] = session
+ s.sessions[sessionKey] = &sessionState{
+ session: session,
+ sessionType: sessionType,
+ }
s.mu.Unlock()
return sessionKey
}
-func (s *Server) unregisterSession(sessionKey SessionKey, session ssh.Session) {
+func (s *Server) unregisterSession(sessionKey sessionKey) {
s.mu.Lock()
+ defer s.mu.Unlock()
+
delete(s.sessions, sessionKey)
- delete(s.sessionJWTUsers, sessionKey)
-
- // Cancel all port forwarding connections for this session
- var connectionsToCancel []ConnectionKey
- for key := range s.sessionCancels {
- if strings.HasPrefix(string(key), string(sessionKey)+"-") {
- connectionsToCancel = append(connectionsToCancel, key)
- }
- }
-
- for _, key := range connectionsToCancel {
- if cancelFunc, exists := s.sessionCancels[key]; exists {
- log.WithField("session", sessionKey).Debugf("cancelling port forwarding context: %s", key)
- cancelFunc()
- delete(s.sessionCancels, key)
- }
- }
-
- if sshConnValue := session.Context().Value(ssh.ContextKeyConn); sshConnValue != nil {
- if sshConn, ok := sshConnValue.(*cryptossh.ServerConn); ok {
- delete(s.sshConnections, sshConn)
- }
- }
-
- s.mu.Unlock()
}
func (s *Server) handlePrivError(logger *log.Entry, session ssh.Session, err error) {
diff --git a/client/ssh/server/sftp.go b/client/ssh/server/sftp.go
index c2b9f552b..199444abb 100644
--- a/client/ssh/server/sftp.go
+++ b/client/ssh/server/sftp.go
@@ -18,14 +18,26 @@ func (s *Server) SetAllowSFTP(allow bool) {
// sftpSubsystemHandler handles SFTP subsystem requests
func (s *Server) sftpSubsystemHandler(sess ssh.Session) {
+ sessionKey := s.registerSession(sess, cmdSFTP)
+ defer s.unregisterSession(sessionKey)
+
+ jwtUsername := s.associateJWTUsername(sess, sessionKey)
+
+ logger := log.WithField("session", sessionKey)
+ if jwtUsername != "" {
+ logger = logger.WithField("jwt_user", jwtUsername)
+ }
+ logger.Info("SFTP session started")
+ defer logger.Info("SFTP session closed")
+
s.mu.RLock()
allowSFTP := s.allowSFTP
s.mu.RUnlock()
if !allowSFTP {
- log.Debugf("SFTP subsystem request denied: SFTP disabled")
+ logger.Debug("SFTP subsystem request denied: SFTP disabled")
if err := sess.Exit(1); err != nil {
- log.Debugf("SFTP session exit failed: %v", err)
+ logger.Debugf("SFTP session exit: %v", err)
}
return
}
@@ -37,31 +49,27 @@ func (s *Server) sftpSubsystemHandler(sess ssh.Session) {
})
if !result.Allowed {
- log.Warnf("SFTP access denied for user %s from %s: %v", sess.User(), sess.RemoteAddr(), result.Error)
+ logger.Warnf("SFTP access denied: %v", result.Error)
if err := sess.Exit(1); err != nil {
- log.Debugf("exit SFTP session: %v", err)
+ logger.Debugf("exit SFTP session: %v", err)
}
return
}
- log.Debugf("SFTP subsystem request from user %s (effective user %s)", sess.User(), result.User.Username)
-
if !result.RequiresUserSwitching {
if err := s.executeSftpDirect(sess); err != nil {
- log.Errorf("SFTP direct execution: %v", err)
+ logger.Errorf("SFTP direct execution: %v", err)
}
return
}
if err := s.executeSftpWithPrivilegeDrop(sess, result.User); err != nil {
- log.Errorf("SFTP privilege drop execution: %v", err)
+ logger.Errorf("SFTP privilege drop execution: %v", err)
}
}
// executeSftpDirect executes SFTP directly without privilege dropping
func (s *Server) executeSftpDirect(sess ssh.Session) error {
- log.Debugf("starting SFTP session for user %s (no privilege dropping)", sess.User())
-
sftpServer, err := sftp.NewServer(sess)
if err != nil {
return fmt.Errorf("SFTP server creation: %w", err)
diff --git a/client/ssh/server/test.go b/client/ssh/server/test.go
index 20930c721..f8abd1752 100644
--- a/client/ssh/server/test.go
+++ b/client/ssh/server/test.go
@@ -3,7 +3,6 @@ package server
import (
"context"
"fmt"
- "net"
"net/netip"
"testing"
"time"
@@ -14,23 +13,21 @@ func StartTestServer(t *testing.T, server *Server) string {
errChan := make(chan error, 1)
go func() {
- ln, err := net.Listen("tcp", "127.0.0.1:0")
- if err != nil {
- errChan <- err
- return
- }
- actualAddr := ln.Addr().String()
- if err := ln.Close(); err != nil {
- errChan <- fmt.Errorf("close temp listener: %w", err)
- return
- }
-
- addrPort := netip.MustParseAddrPort(actualAddr)
+ // Use port 0 to let the OS assign a free port
+ addrPort := netip.MustParseAddrPort("127.0.0.1:0")
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
- started <- actualAddr
+
+ // Get the actual listening address from the server
+ actualAddr := server.Addr()
+ if actualAddr == nil {
+ errChan <- fmt.Errorf("server started but no listener address available")
+ return
+ }
+
+ started <- actualAddr.String()
}()
select {
diff --git a/client/status/status.go b/client/status/status.go
index d975f0e29..4f31f3637 100644
--- a/client/status/status.go
+++ b/client/status/status.go
@@ -82,10 +82,11 @@ type NsServerGroupStateOutput struct {
}
type SSHSessionOutput struct {
- Username string `json:"username" yaml:"username"`
- RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"`
- Command string `json:"command" yaml:"command"`
- JWTUsername string `json:"jwtUsername,omitempty" yaml:"jwtUsername,omitempty"`
+ Username string `json:"username" yaml:"username"`
+ RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"`
+ Command string `json:"command" yaml:"command"`
+ JWTUsername string `json:"jwtUsername,omitempty" yaml:"jwtUsername,omitempty"`
+ PortForwards []string `json:"portForwards,omitempty" yaml:"portForwards,omitempty"`
}
type SSHServerStateOutput struct {
@@ -220,6 +221,7 @@ func mapSSHServer(sshServerState *proto.SSHServerState) SSHServerStateOutput {
RemoteAddress: session.GetRemoteAddress(),
Command: session.GetCommand(),
JWTUsername: session.GetJwtUsername(),
+ PortForwards: session.GetPortForwards(),
})
}
@@ -475,6 +477,9 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
)
}
sshServerStatus += "\n " + sessionDisplay
+ for _, pf := range session.PortForwards {
+ sshServerStatus += "\n " + pf
+ }
}
}
}
diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go
index 8f99608e7..78934ea95 100644
--- a/client/ui/client_ui.go
+++ b/client/ui/client_ui.go
@@ -34,6 +34,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
+ protobuf "google.golang.org/protobuf/proto"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal"
@@ -43,7 +44,6 @@ import (
"github.com/netbirdio/netbird/client/ui/desktop"
"github.com/netbirdio/netbird/client/ui/event"
"github.com/netbirdio/netbird/client/ui/process"
-
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
@@ -87,22 +87,24 @@ func main() {
// Create the service client (this also builds the settings or networks UI if requested).
client := newServiceClient(&newServiceClientArgs{
- addr: flags.daemonAddr,
- logFile: logFile,
- app: a,
- showSettings: flags.showSettings,
- showNetworks: flags.showNetworks,
- showLoginURL: flags.showLoginURL,
- showDebug: flags.showDebug,
- showProfiles: flags.showProfiles,
- showQuickActions: flags.showQuickActions,
+ addr: flags.daemonAddr,
+ logFile: logFile,
+ app: a,
+ showSettings: flags.showSettings,
+ showNetworks: flags.showNetworks,
+ showLoginURL: flags.showLoginURL,
+ showDebug: flags.showDebug,
+ showProfiles: flags.showProfiles,
+ showQuickActions: flags.showQuickActions,
+ showUpdate: flags.showUpdate,
+ showUpdateVersion: flags.showUpdateVersion,
})
// Watch for theme/settings changes to update the icon.
go watchSettingsChanges(a, client)
// Run in window mode if any UI flag was set.
- if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions {
+ if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions || flags.showUpdate {
a.Run()
return
}
@@ -128,15 +130,17 @@ func main() {
}
type cliFlags struct {
- daemonAddr string
- showSettings bool
- showNetworks bool
- showProfiles bool
- showDebug bool
- showLoginURL bool
- showQuickActions bool
- errorMsg string
- saveLogsInFile bool
+ daemonAddr string
+ showSettings bool
+ showNetworks bool
+ showProfiles bool
+ showDebug bool
+ showLoginURL bool
+ showQuickActions bool
+ errorMsg string
+ saveLogsInFile bool
+ showUpdate bool
+ showUpdateVersion string
}
// parseFlags reads and returns all needed command-line flags.
@@ -156,6 +160,8 @@ func parseFlags() *cliFlags {
flag.StringVar(&flags.errorMsg, "error-msg", "", "displays an error message window")
flag.BoolVar(&flags.saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir()))
flag.BoolVar(&flags.showLoginURL, "login-url", false, "show login URL in a popup window")
+ flag.BoolVar(&flags.showUpdate, "update", false, "show update progress window")
+ flag.StringVar(&flags.showUpdateVersion, "update-version", "", "version to update to")
flag.Parse()
return &flags
}
@@ -306,6 +312,8 @@ type serviceClient struct {
daemonVersion string
updateIndicationLock sync.Mutex
isUpdateIconActive bool
+ settingsEnabled bool
+ profilesEnabled bool
showNetworks bool
wNetworks fyne.Window
wProfiles fyne.Window
@@ -319,6 +327,8 @@ type serviceClient struct {
mExitNodeDeselectAll *systray.MenuItem
logFile string
wLoginURL fyne.Window
+ wUpdateProgress fyne.Window
+ updateContextCancel context.CancelFunc
connectCancel context.CancelFunc
}
@@ -329,15 +339,17 @@ type menuHandler struct {
}
type newServiceClientArgs struct {
- addr string
- logFile string
- app fyne.App
- showSettings bool
- showNetworks bool
- showDebug bool
- showLoginURL bool
- showProfiles bool
- showQuickActions bool
+ addr string
+ logFile string
+ app fyne.App
+ showSettings bool
+ showNetworks bool
+ showDebug bool
+ showLoginURL bool
+ showProfiles bool
+ showQuickActions bool
+ showUpdate bool
+ showUpdateVersion string
}
// newServiceClient instance constructor
@@ -355,7 +367,7 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
showAdvancedSettings: args.showSettings,
showNetworks: args.showNetworks,
- update: version.NewUpdate("nb/client-ui"),
+ update: version.NewUpdateAndStart("nb/client-ui"),
}
s.eventHandler = newEventHandler(s)
@@ -375,6 +387,8 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
s.showProfilesUI()
case args.showQuickActions:
s.showQuickActionsUI()
+ case args.showUpdate:
+ s.showUpdateProgress(ctx, args.showUpdateVersion)
}
return s
@@ -814,7 +828,7 @@ func (s *serviceClient) handleSSOLogin(ctx context.Context, loginResp *proto.Log
return nil
}
-func (s *serviceClient) menuUpClick(ctx context.Context) error {
+func (s *serviceClient) menuUpClick(ctx context.Context, wannaAutoUpdate bool) error {
systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting)
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
@@ -836,7 +850,9 @@ func (s *serviceClient) menuUpClick(ctx context.Context) error {
return nil
}
- if _, err := conn.Up(ctx, &proto.UpRequest{}); err != nil {
+ if _, err := s.conn.Up(s.ctx, &proto.UpRequest{
+ AutoUpdate: protobuf.Bool(wannaAutoUpdate),
+ }); err != nil {
return fmt.Errorf("start connection: %w", err)
}
@@ -893,7 +909,7 @@ func (s *serviceClient) updateStatus() error {
var systrayIconState bool
switch {
- case status.Status == string(internal.StatusConnected):
+ case status.Status == string(internal.StatusConnected) && !s.connected:
s.connected = true
s.sendNotification = true
if s.isUpdateIconActive {
@@ -907,6 +923,7 @@ func (s *serviceClient) updateStatus() error {
s.mUp.Disable()
s.mDown.Enable()
s.mNetworks.Enable()
+ s.mExitNode.Enable()
go s.updateExitNodes()
systrayIconState = true
case status.Status == string(internal.StatusConnecting):
@@ -1097,6 +1114,26 @@ func (s *serviceClient) onTrayReady() {
s.updateExitNodes()
}
})
+ s.eventManager.AddHandler(func(event *proto.SystemEvent) {
+ // TODO: switch to the dedicated event Category field instead of keying off the "progress_window" metadata entry
+ if windowAction, ok := event.Metadata["progress_window"]; ok {
+ targetVersion, ok := event.Metadata["version"]
+ if !ok {
+ targetVersion = "unknown"
+ }
+ log.Debugf("window action: %v", windowAction)
+ if windowAction == "show" {
+ if s.updateContextCancel != nil {
+ s.updateContextCancel()
+ s.updateContextCancel = nil
+ }
+
+ subCtx, cancel := context.WithCancel(s.ctx)
+ go s.eventHandler.runSelfCommand(subCtx, "update", "--update-version", targetVersion)
+ s.updateContextCancel = cancel
+ }
+ }
+ })
go s.eventManager.Start(s.ctx)
go s.eventHandler.listen(s.ctx)
@@ -1240,19 +1277,22 @@ func (s *serviceClient) checkAndUpdateFeatures() {
return
}
+ s.updateIndicationLock.Lock()
+ defer s.updateIndicationLock.Unlock()
+
// Update settings menu based on current features
- if features != nil && features.DisableUpdateSettings {
- s.setSettingsEnabled(false)
- } else {
- s.setSettingsEnabled(true)
+ settingsEnabled := features == nil || !features.DisableUpdateSettings
+ if s.settingsEnabled != settingsEnabled {
+ s.settingsEnabled = settingsEnabled
+ s.setSettingsEnabled(settingsEnabled)
}
// Update profile menu based on current features
if s.mProfile != nil {
- if features != nil && features.DisableProfiles {
- s.mProfile.setEnabled(false)
- } else {
- s.mProfile.setEnabled(true)
+ profilesEnabled := features == nil || !features.DisableProfiles
+ if s.profilesEnabled != profilesEnabled {
+ s.profilesEnabled = profilesEnabled
+ s.mProfile.setEnabled(profilesEnabled)
}
}
}
diff --git a/client/ui/event_handler.go b/client/ui/event_handler.go
index e0b619411..9ffacd926 100644
--- a/client/ui/event_handler.go
+++ b/client/ui/event_handler.go
@@ -80,7 +80,7 @@ func (h *eventHandler) handleConnectClick() {
go func() {
defer connectCancel()
- if err := h.client.menuUpClick(connectCtx); err != nil {
+ if err := h.client.menuUpClick(connectCtx, true); err != nil {
st, ok := status.FromError(err)
if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) {
log.Debugf("connect operation cancelled by user")
@@ -185,7 +185,7 @@ func (h *eventHandler) handleAdvancedSettingsClick() {
go func() {
defer h.client.mAdvancedSettings.Enable()
defer h.client.getSrvConfig()
- h.runSelfCommand(h.client.ctx, "settings", "true")
+ h.runSelfCommand(h.client.ctx, "settings")
}()
}
@@ -193,7 +193,7 @@ func (h *eventHandler) handleCreateDebugBundleClick() {
h.client.mCreateDebugBundle.Disable()
go func() {
defer h.client.mCreateDebugBundle.Enable()
- h.runSelfCommand(h.client.ctx, "debug", "true")
+ h.runSelfCommand(h.client.ctx, "debug")
}()
}
@@ -217,7 +217,7 @@ func (h *eventHandler) handleNetworksClick() {
h.client.mNetworks.Disable()
go func() {
defer h.client.mNetworks.Enable()
- h.runSelfCommand(h.client.ctx, "networks", "true")
+ h.runSelfCommand(h.client.ctx, "networks")
}()
}
@@ -237,17 +237,21 @@ func (h *eventHandler) updateConfigWithErr() error {
return nil
}
-func (h *eventHandler) runSelfCommand(ctx context.Context, command, arg string) {
+func (h *eventHandler) runSelfCommand(ctx context.Context, command string, args ...string) {
proc, err := os.Executable()
if err != nil {
log.Errorf("error getting executable path: %v", err)
return
}
- cmd := exec.CommandContext(ctx, proc,
- fmt.Sprintf("--%s=%s", command, arg),
+ // Build the full command arguments
+ cmdArgs := []string{
+ fmt.Sprintf("--%s=true", command),
fmt.Sprintf("--daemon-addr=%s", h.client.addr),
- )
+ }
+ cmdArgs = append(cmdArgs, args...)
+
+ cmd := exec.CommandContext(ctx, proc, cmdArgs...)
if out := h.client.attachOutput(cmd); out != nil {
defer func() {
@@ -257,17 +261,17 @@ func (h *eventHandler) runSelfCommand(ctx context.Context, command, arg string)
}()
}
- log.Printf("running command: %s --%s=%s --daemon-addr=%s", proc, command, arg, h.client.addr)
+ log.Printf("running command: %s", cmd.String())
if err := cmd.Run(); err != nil {
var exitErr *exec.ExitError
if errors.As(err, &exitErr) {
- log.Printf("command '%s %s' failed with exit code %d", command, arg, exitErr.ExitCode())
+ log.Printf("command '%s' failed with exit code %d", cmd.String(), exitErr.ExitCode())
}
return
}
- log.Printf("command '%s %s' completed successfully", command, arg)
+ log.Printf("command '%s' completed successfully", cmd.String())
}
func (h *eventHandler) logout(ctx context.Context) error {
diff --git a/client/ui/font_windows.go b/client/ui/font_windows.go
index 93b23a21b..6346a9fb9 100644
--- a/client/ui/font_windows.go
+++ b/client/ui/font_windows.go
@@ -31,7 +31,6 @@ func (s *serviceClient) getWindowsFontFilePath() string {
"chr-CHER-US": "Gadugi.ttf",
"zh-HK": "Segoeui.ttf",
"zh-TW": "Segoeui.ttf",
- "ja-JP": "Yugothm.ttc",
"km-KH": "Leelawui.ttf",
"ko-KR": "Malgun.ttf",
"th-TH": "Leelawui.ttf",
diff --git a/client/ui/profile.go b/client/ui/profile.go
index 74189c9a0..a38d8918a 100644
--- a/client/ui/profile.go
+++ b/client/ui/profile.go
@@ -397,7 +397,7 @@ type profileMenu struct {
logoutSubItem *subItem
profilesState []Profile
downClickCallback func() error
- upClickCallback func(context.Context) error
+ upClickCallback func(context.Context, bool) error
getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error)
loadSettingsCallback func()
app fyne.App
@@ -411,7 +411,7 @@ type newProfileMenuArgs struct {
profileMenuItem *systray.MenuItem
emailMenuItem *systray.MenuItem
downClickCallback func() error
- upClickCallback func(context.Context) error
+ upClickCallback func(context.Context, bool) error
getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error)
loadSettingsCallback func()
app fyne.App
@@ -579,7 +579,7 @@ func (p *profileMenu) refresh() {
connectCtx, connectCancel := context.WithCancel(p.ctx)
p.serviceClient.connectCancel = connectCancel
- if err := p.upClickCallback(connectCtx); err != nil {
+ if err := p.upClickCallback(connectCtx, false); err != nil {
log.Errorf("failed to handle up click after switching profile: %v", err)
}
diff --git a/client/ui/quickactions.go b/client/ui/quickactions.go
index bf47ac434..76440d684 100644
--- a/client/ui/quickactions.go
+++ b/client/ui/quickactions.go
@@ -267,7 +267,7 @@ func (s *serviceClient) showQuickActionsUI() {
connCmd := connectCommand{
connectClient: func() error {
- return s.menuUpClick(s.ctx)
+ return s.menuUpClick(s.ctx, false)
},
}
diff --git a/client/ui/update.go b/client/ui/update.go
new file mode 100644
index 000000000..25c317bdf
--- /dev/null
+++ b/client/ui/update.go
@@ -0,0 +1,140 @@
+//go:build !(linux && 386)
+
+package main
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "strings"
+ "time"
+
+ "fyne.io/fyne/v2/container"
+ "fyne.io/fyne/v2/widget"
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/client/proto"
+)
+
+func (s *serviceClient) showUpdateProgress(ctx context.Context, version string) {
+ log.Infof("show installer progress window: %s", version)
+ s.wUpdateProgress = s.app.NewWindow("Automatically updating client")
+
+ statusLabel := widget.NewLabel("Updating...")
+ infoLabel := widget.NewLabel(fmt.Sprintf("Your client version is older than the auto-update version set in Management.\nUpdating client to: %s.", version))
+ content := container.NewVBox(infoLabel, statusLabel)
+ s.wUpdateProgress.SetContent(content)
+ s.wUpdateProgress.CenterOnScreen()
+ s.wUpdateProgress.SetFixedSize(true)
+ s.wUpdateProgress.SetCloseIntercept(func() {
+ // intentionally empty: blocks the user from closing the window until the update result is known (reset after 10s below)
+ })
+ s.wUpdateProgress.RequestFocus()
+ s.wUpdateProgress.Show()
+
+ updateWindowCtx, cancel := context.WithTimeout(ctx, 15*time.Minute)
+
+ // Initialize dot updater
+ updateText := dotUpdater()
+
+ // Channel to receive the result from RPC call
+ resultErrCh := make(chan error, 1)
+ resultOkCh := make(chan struct{}, 1)
+
+ // Start RPC call in background
+ go func() {
+ conn, err := s.getSrvClient(defaultFailTimeout)
+ if err != nil {
+ log.Infof("backend not reachable, upgrade in progress: %v", err)
+ close(resultOkCh)
+ return
+ }
+
+ resp, err := conn.GetInstallerResult(updateWindowCtx, &proto.InstallerResultRequest{})
+ if err != nil {
+ log.Infof("backend stopped responding, upgrade in progress: %v", err)
+ close(resultOkCh)
+ return
+ }
+
+ if !resp.Success {
+ resultErrCh <- mapInstallError(resp.ErrorMsg)
+ return
+ }
+
+ // Success
+ close(resultOkCh)
+ }()
+
+ // Update UI with dots and wait for result
+ go func() {
+ ticker := time.NewTicker(time.Second)
+ defer ticker.Stop()
+ defer cancel()
+
+ // allow closing update window after 10 sec
+ timerResetCloseInterceptor := time.NewTimer(10 * time.Second)
+ defer timerResetCloseInterceptor.Stop()
+
+ for {
+ select {
+ case <-updateWindowCtx.Done():
+ s.showInstallerResult(statusLabel, updateWindowCtx.Err())
+ return
+ case err := <-resultErrCh:
+ s.showInstallerResult(statusLabel, err)
+ return
+ case <-resultOkCh:
+ log.Info("backend exited, upgrade in progress, closing all UI")
+ killParentUIProcess()
+ s.app.Quit()
+ return
+ case <-ticker.C:
+ statusLabel.SetText(updateText())
+ case <-timerResetCloseInterceptor.C:
+ s.wUpdateProgress.SetCloseIntercept(nil)
+ }
+ }
+ }()
+}
+
+func (s *serviceClient) showInstallerResult(statusLabel *widget.Label, err error) {
+ s.wUpdateProgress.SetCloseIntercept(nil)
+ switch {
+ case errors.Is(err, context.DeadlineExceeded):
+ log.Warn("update watcher timed out")
+ statusLabel.SetText("Update timed out. Please try again.")
+ case errors.Is(err, context.Canceled):
+ log.Info("update watcher canceled")
+ statusLabel.SetText("Update canceled.")
+ case err != nil:
+ log.Errorf("update failed: %v", err)
+ statusLabel.SetText("Update failed: " + err.Error())
+ default:
+ s.wUpdateProgress.Close()
+ }
+}
+
+// dotUpdater returns a closure that cycles through dots for a loading animation.
+func dotUpdater() func() string {
+ dotCount := 0
+ return func() string {
+ dotCount = (dotCount + 1) % 4
+ return fmt.Sprintf("%s%s", "Updating", strings.Repeat(".", dotCount))
+ }
+}
+
+func mapInstallError(msg string) error {
+ msg = strings.ToLower(strings.TrimSpace(msg))
+
+ switch {
+ case strings.Contains(msg, "deadline exceeded"), strings.Contains(msg, "timeout"):
+ return context.DeadlineExceeded
+ case strings.Contains(msg, "canceled"), strings.Contains(msg, "cancelled"):
+ return context.Canceled
+ case msg == "":
+ return errors.New("unknown update error")
+ default:
+ return errors.New(msg)
+ }
+}
diff --git a/client/ui/update_notwindows.go b/client/ui/update_notwindows.go
new file mode 100644
index 000000000..5766f18f7
--- /dev/null
+++ b/client/ui/update_notwindows.go
@@ -0,0 +1,7 @@
+//go:build !windows && !(linux && 386)
+
+package main
+
+func killParentUIProcess() {
+ // No-op on non-Windows platforms; the real implementation is in update_windows.go
+}
diff --git a/client/ui/update_windows.go b/client/ui/update_windows.go
new file mode 100644
index 000000000..1b03936f9
--- /dev/null
+++ b/client/ui/update_windows.go
@@ -0,0 +1,44 @@
+//go:build windows
+
+package main
+
+import (
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/sys/windows"
+
+ nbprocess "github.com/netbirdio/netbird/client/ui/process"
+)
+
+// killParentUIProcess finds and kills the parent systray UI process on Windows.
+// This is a workaround in case the MSI installer fails to properly terminate the UI process.
+// The installer should handle this via util:CloseApplication with TerminateProcess, but this
+// provides an additional safety mechanism to ensure the UI is closed before the upgrade proceeds.
+func killParentUIProcess() {
+ pid, running, err := nbprocess.IsAnotherProcessRunning()
+ if err != nil {
+ log.Warnf("failed to check for parent UI process: %v", err)
+ return
+ }
+
+ if !running {
+ log.Debug("no parent UI process found to kill")
+ return
+ }
+
+ log.Infof("killing parent UI process (PID: %d)", pid)
+
+ // Open the process with terminate rights
+ handle, err := windows.OpenProcess(windows.PROCESS_TERMINATE, false, uint32(pid))
+ if err != nil {
+ log.Warnf("failed to open parent process %d: %v", pid, err)
+ return
+ }
+ defer func() {
+ _ = windows.CloseHandle(handle)
+ }()
+
+ // Terminate the process with exit code 0
+ if err := windows.TerminateProcess(handle, 0); err != nil {
+ log.Warnf("failed to terminate parent process %d: %v", pid, err)
+ }
+}
diff --git a/go.mod b/go.mod
index 8f4ec530b..23cf0f37d 100644
--- a/go.mod
+++ b/go.mod
@@ -8,22 +8,22 @@ require (
github.com/cloudflare/circl v1.3.3 // indirect
github.com/golang/protobuf v1.5.4
github.com/google/uuid v1.6.0
- github.com/gorilla/mux v1.8.0
+ github.com/gorilla/mux v1.8.1
github.com/kardianos/service v1.2.3-0.20240613133416-becf2eb62b83
github.com/onsi/ginkgo v1.16.5
github.com/onsi/gomega v1.27.6
github.com/rs/cors v1.8.0
github.com/sirupsen/logrus v1.9.3
- github.com/spf13/cobra v1.7.0
- github.com/spf13/pflag v1.0.5
+ github.com/spf13/cobra v1.10.1
+ github.com/spf13/pflag v1.0.9
github.com/vishvananda/netlink v1.3.1
- golang.org/x/crypto v0.45.0
- golang.org/x/sys v0.38.0
+ golang.org/x/crypto v0.46.0
+ golang.org/x/sys v0.39.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3
- google.golang.org/grpc v1.73.0
- google.golang.org/protobuf v1.36.8
+ google.golang.org/grpc v1.77.0
+ google.golang.org/protobuf v1.36.10
gopkg.in/natefinch/lumberjack.v2 v2.0.0
)
@@ -41,6 +41,8 @@ require (
github.com/coder/websocket v1.8.13
github.com/coreos/go-iptables v0.7.0
github.com/creack/pty v1.1.18
+ github.com/dexidp/dex v0.0.0-00010101000000-000000000000
+ github.com/dexidp/dex/api/v2 v2.4.0
github.com/eko/gocache/lib/v4 v4.2.0
github.com/eko/gocache/store/go_cache/v4 v4.2.2
github.com/eko/gocache/store/redis/v4 v4.2.2
@@ -78,7 +80,7 @@ require (
github.com/pion/transport/v3 v3.0.7
github.com/pion/turn/v3 v3.0.1
github.com/pkg/sftp v1.13.9
- github.com/prometheus/client_golang v1.22.0
+ github.com/prometheus/client_golang v1.23.2
github.com/quic-go/quic-go v0.49.1
github.com/redis/go-redis/v9 v9.7.3
github.com/rs/xid v1.3.0
@@ -96,11 +98,11 @@ require (
github.com/vmihailenco/msgpack/v5 v5.4.1
github.com/yusufpapurcu/wmi v1.2.4
github.com/zcalusic/sysinfo v1.1.3
- go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0
- go.opentelemetry.io/otel v1.35.0
+ go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0
+ go.opentelemetry.io/otel v1.38.0
go.opentelemetry.io/otel/exporters/prometheus v0.48.0
- go.opentelemetry.io/otel/metric v1.35.0
- go.opentelemetry.io/otel/sdk/metric v1.35.0
+ go.opentelemetry.io/otel/metric v1.38.0
+ go.opentelemetry.io/otel/sdk/metric v1.38.0
go.uber.org/mock v0.5.0
go.uber.org/zap v1.27.0
goauthentik.io/api/v3 v3.2023051.3
@@ -108,11 +110,11 @@ require (
golang.org/x/mobile v0.0.0-20251113184115-a159579294ab
golang.org/x/mod v0.30.0
golang.org/x/net v0.47.0
- golang.org/x/oauth2 v0.30.0
- golang.org/x/sync v0.18.0
- golang.org/x/term v0.37.0
- golang.org/x/time v0.12.0
- google.golang.org/api v0.177.0
+ golang.org/x/oauth2 v0.34.0
+ golang.org/x/sync v0.19.0
+ golang.org/x/term v0.38.0
+ golang.org/x/time v0.14.0
+ google.golang.org/api v0.257.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/mysql v1.5.7
gorm.io/driver/postgres v1.5.7
@@ -122,13 +124,18 @@ require (
)
require (
- cloud.google.com/go/auth v0.3.0 // indirect
- cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
- cloud.google.com/go/compute/metadata v0.6.0 // indirect
- dario.cat/mergo v1.0.0 // indirect
+ cloud.google.com/go/auth v0.17.0 // indirect
+ cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
+ cloud.google.com/go/compute/metadata v0.9.0 // indirect
+ dario.cat/mergo v1.0.1 // indirect
filippo.io/edwards25519 v1.1.0 // indirect
+ github.com/AppsFlyer/go-sundheit v0.6.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
+ github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect
github.com/BurntSushi/toml v1.5.0 // indirect
+ github.com/Masterminds/goutils v1.1.1 // indirect
+ github.com/Masterminds/semver/v3 v3.3.0 // indirect
+ github.com/Masterminds/sprig/v3 v3.3.0 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/Microsoft/hcsshim v0.12.3 // indirect
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
@@ -149,12 +156,14 @@ require (
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 // indirect
github.com/aws/smithy-go v1.22.2 // indirect
+ github.com/beevik/etree v1.6.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/caddyserver/zerossl v0.1.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/containerd/containerd v1.7.29 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/containerd/platforms v0.2.1 // indirect
+ github.com/coreos/go-oidc/v3 v3.14.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
@@ -168,26 +177,30 @@ require (
github.com/fyne-io/glfw-js v0.3.0 // indirect
github.com/fyne-io/image v0.1.1 // indirect
github.com/fyne-io/oksvg v0.2.0 // indirect
+ github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
- github.com/go-logr/logr v1.4.2 // indirect
+ github.com/go-jose/go-jose/v4 v4.1.3 // indirect
+ github.com/go-ldap/ldap/v3 v3.4.12 // indirect
+ github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect
- github.com/go-sql-driver/mysql v1.8.1 // indirect
+ github.com/go-sql-driver/mysql v1.9.3 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/go-text/render v0.2.0 // indirect
github.com/go-text/typesetting v0.2.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
- github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/btree v1.1.2 // indirect
github.com/google/pprof v0.0.0-20211214055906-6f57359322fd // indirect
- github.com/google/s2a-go v0.1.7 // indirect
- github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
- github.com/googleapis/gax-go/v2 v2.12.3 // indirect
+ github.com/google/s2a-go v0.1.9 // indirect
+ github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect
+ github.com/googleapis/gax-go/v2 v2.15.0 // indirect
+ github.com/gorilla/handlers v1.5.2 // indirect
github.com/hack-pad/go-indexeddb v0.3.2 // indirect
github.com/hack-pad/safejs v0.1.0 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-uuid v1.0.3 // indirect
+ github.com/huandu/xstrings v1.5.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
@@ -196,18 +209,23 @@ require (
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
+ github.com/jonboulle/clockwork v0.5.0 // indirect
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 // indirect
github.com/kelseyhightower/envconfig v1.4.0 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/kr/fs v0.1.0 // indirect
+ github.com/lib/pq v1.10.9 // indirect
github.com/libdns/libdns v0.2.2 // indirect
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect
github.com/magiconair/properties v1.8.7 // indirect
- github.com/mattn/go-sqlite3 v1.14.22 // indirect
+ github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect
+ github.com/mattn/go-sqlite3 v1.14.32 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect
github.com/mholt/acmez/v2 v2.0.1 // indirect
+ github.com/mitchellh/copystructure v1.2.0 // indirect
+ github.com/mitchellh/reflectwalk v1.0.2 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/patternmatcher v0.6.0 // indirect
github.com/moby/sys/sequential v0.5.0 // indirect
@@ -230,11 +248,14 @@ require (
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
- github.com/prometheus/client_model v0.6.1 // indirect
- github.com/prometheus/common v0.62.0 // indirect
- github.com/prometheus/procfs v0.15.1 // indirect
+ github.com/prometheus/client_model v0.6.2 // indirect
+ github.com/prometheus/common v0.66.1 // indirect
+ github.com/prometheus/procfs v0.16.1 // indirect
+ github.com/russellhaering/goxmldsig v1.5.0 // indirect
github.com/rymdport/portal v0.4.2 // indirect
github.com/shoenig/go-m1cpu v0.1.6 // indirect
+ github.com/shopspring/decimal v1.4.0 // indirect
+ github.com/spf13/cast v1.7.0 // indirect
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect
github.com/stretchr/objx v0.5.2 // indirect
@@ -245,17 +266,17 @@ require (
github.com/wlynxg/anet v0.0.3 // indirect
github.com/yuin/goldmark v1.7.8 // indirect
github.com/zeebo/blake3 v0.2.3 // indirect
- go.opencensus.io v0.24.0 // indirect
- go.opentelemetry.io/auto/sdk v1.1.0 // indirect
- go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
- go.opentelemetry.io/otel/sdk v1.35.0 // indirect
- go.opentelemetry.io/otel/trace v1.35.0 // indirect
+ go.opentelemetry.io/auto/sdk v1.2.1 // indirect
+ go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect
+ go.opentelemetry.io/otel/sdk v1.38.0 // indirect
+ go.opentelemetry.io/otel/trace v1.38.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
+ go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/image v0.33.0 // indirect
- golang.org/x/text v0.31.0 // indirect
+ golang.org/x/text v0.32.0 // indirect
golang.org/x/tools v0.39.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
- google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect
+ google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
)
@@ -271,3 +292,5 @@ replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-2023080111
replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51
replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944
+
+replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0
diff --git a/go.sum b/go.sum
index f10e1e6da..354c7732e 100644
--- a/go.sum
+++ b/go.sum
@@ -1,15 +1,14 @@
-cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
-cloud.google.com/go/auth v0.3.0 h1:PRyzEpGfx/Z9e8+lHsbkoUVXD0gnu4MNmm7Gp8TQNIs=
-cloud.google.com/go/auth v0.3.0/go.mod h1:lBv6NKTWp8E3LPzmO1TbiiRKc4drLOfHsgmlH9ogv5w=
-cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4=
-cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q=
+cloud.google.com/go/auth v0.17.0 h1:74yCm7hCj2rUyyAocqnFzsAYXgJhrG26XCFimrc/Kz4=
+cloud.google.com/go/auth v0.17.0/go.mod h1:6wv/t5/6rOPAX4fJiRjKkJCvswLwdet7G8+UGXt7nCQ=
+cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
+cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
-cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I=
-cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg=
+cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
+cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
cunicu.li/go-rosenpass v0.4.0 h1:LtPtBgFWY/9emfgC4glKLEqS0MJTylzV6+ChRhiZERw=
cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJplf/M=
-dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk=
-dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
+dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
+dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
fyne.io/fyne/v2 v2.7.0 h1:GvZSpE3X0liU/fqstInVvRsaboIVpIWQ4/sfjDGIGGQ=
@@ -18,17 +17,28 @@ fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58 h1:eA5/u2XRd8OUkoMqEv3IBlF
fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
+github.com/AppsFlyer/go-sundheit v0.6.0 h1:d2hBvCjBSb2lUsEWGfPigr4MCOt04sxB+Rppl0yUMSk=
+github.com/AppsFlyer/go-sundheit v0.6.0/go.mod h1:LDdBHD6tQBtmHsdW+i1GwdTt6Wqc0qazf5ZEJVTbTME=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
-github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
+github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 h1:mFRzDkZVAjdal+s7s0MwaRv9igoPqLRdzOLzw/8Xvq8=
+github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU=
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
+github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI=
+github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU=
+github.com/Masterminds/semver/v3 v3.3.0 h1:B8LGeaivUe71a5qox1ICM/JLl0NqZSW5CHyL+hmvYS0=
+github.com/Masterminds/semver/v3 v3.3.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM=
+github.com/Masterminds/sprig/v3 v3.3.0 h1:mQh0Yrg1XPo6vjYXgtf5OtijNAKJRNcTdOOGZe3tPhs=
+github.com/Masterminds/sprig/v3 v3.3.0/go.mod h1:Zy1iXRYNqNLUolqCpL4uhk6SHUMAOSCzdgBfDb35Lz0=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/Microsoft/hcsshim v0.12.3 h1:LS9NXqXhMoqNCplK1ApmVSfB4UnVLRDWRapB6EIlxE0=
github.com/Microsoft/hcsshim v0.12.3/go.mod h1:Iyl1WVpZzr+UkzjekHZbV8o5Z9ZkxNGx6CtY2Qg/JVQ=
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible h1:hqcTK6ZISdip65SR792lwYJTa/axESA0889D3UlZbLo=
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible/go.mod h1:6B1nuc1MUs6c62ODZDl7hVE5Pv7O2XGSkgg2olnq34I=
+github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e h1:4dAU9FXIyQktpoUAgOJK3OTFc/xug0PCXYCqU0FgDKI=
+github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/awnumar/memcall v0.4.0 h1:B7hgZYdfH6Ot1Goaz8jGne/7i8xD4taZie/PNSFZ29g=
@@ -73,6 +83,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 h1:1XuUZ8mYJw9B6lzAkXhqHlJd/Xv
github.com/aws/aws-sdk-go-v2/service/sts v1.33.19/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4=
github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ=
github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg=
+github.com/beevik/etree v1.6.0 h1:u8Kwy8pp9D9XeITj2Z0XtA5qqZEmtJtuXZRQi+j03eE=
+github.com/beevik/etree v1.6.0/go.mod h1:bh4zJxiIr62SOf9pRzN7UUYaEDa9HEKafK25+sLc0Gc=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
@@ -87,7 +99,6 @@ github.com/caddyserver/zerossl v0.1.3 h1:onS+pxp3M8HnHpN5MMbOMyNjmTheJyWRaZYwn+Y
github.com/caddyserver/zerossl v0.1.3/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
-github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
@@ -95,8 +106,6 @@ github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5P
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
-github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
-github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
github.com/containerd/containerd v1.7.29 h1:90fWABQsaN9mJhGkoVnuzEY+o1XDPbg9BTC9QTAHnuE=
@@ -107,9 +116,11 @@ github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpS
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
github.com/coreos/go-iptables v0.7.0 h1:XWM3V+MPRr5/q51NuWSgU0fqMad64Zyxs8ZUoMsamr8=
github.com/coreos/go-iptables v0.7.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q=
+github.com/coreos/go-oidc/v3 v3.14.1 h1:9ePWwfdwC4QKRlCXsJGou56adA/owXczOzwKdOumLqk=
+github.com/coreos/go-oidc/v3 v3.14.1/go.mod h1:HaZ3szPaZ0e4r6ebqvsLWlk2Tn+aejfmrfah6hnSYEU=
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
-github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
+github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:/DS5cDX3FJdl+XaN2D7XAwFpuanTxnp52DBLZAaJKx0=
@@ -117,6 +128,8 @@ github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70J
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/dexidp/dex/api/v2 v2.4.0 h1:gNba7n6BKVp8X4Jp24cxYn5rIIGhM6kDOXcZoL6tr9A=
+github.com/dexidp/dex/api/v2 v2.4.0/go.mod h1:/p550ADvFFh7K95VmhUD+jgm15VdaNnab9td8DHOpyI=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
@@ -133,14 +146,14 @@ github.com/eko/gocache/store/go_cache/v4 v4.2.2 h1:tAI9nl6TLoJyKG1ujF0CS0n/IgTEM
github.com/eko/gocache/store/go_cache/v4 v4.2.2/go.mod h1:T9zkHokzr8K9EiC7RfMbDg6HSwaV6rv3UdcNu13SGcA=
github.com/eko/gocache/store/redis/v4 v4.2.2 h1:Thw31fzGuH3WzJywsdbMivOmP550D6JS7GDHhvCJPA0=
github.com/eko/gocache/store/redis/v4 v4.2.2/go.mod h1:LaTxLKx9TG/YUEybQvPMij++D7PBTIJ4+pzvk0ykz0w=
-github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
-github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
-github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
-github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/felixge/fgprof v0.9.3 h1:VvyZxILNuCiUCSXtPtYmmtGvb65nqXh2QFWc0Wpf2/g=
github.com/felixge/fgprof v0.9.3/go.mod h1:RdbpDgzqYVh/T9fPELJyV7EYJuHB55UTEULNun8eiPw=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
+github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw=
+github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g=
+github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
+github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fredbi/uri v1.1.1 h1:xZHJC08GZNIUhbP5ImTHnt5Ya0T8FI2VAwI/37kh2Ko=
github.com/fredbi/uri v1.1.1/go.mod h1:4+DZQ5zBjEwQCDmXW5JdIjz0PUA+yJbvtBv+u+adr5o=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
@@ -159,13 +172,19 @@ github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm
github.com/gin-gonic/gin v1.5.0/go.mod h1:Nd6IXA8m5kNZdNEHMBd93KT+mdY3+bewLgRvmCsR2Do=
github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c=
github.com/gliderlabs/ssh v0.3.8/go.mod h1:xYoytBv1sV0aL3CavoDuJIQNURXkkfPA/wxQ1pL1fAU=
+github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo=
+github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 h1:5BVwOaUSBTlVZowGO6VZGw2H/zl9nrd3eCZfYV+NfQA=
github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71/go.mod h1:9YTyiznxEY1fVinfM7RvRcjRHbw2xLBJ3AAGIT0I4Nw=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a h1:vxnBhFDDT+xzxf1jTJKMKZw3H0swfWk9RpWbBbDK5+0=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
+github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs=
+github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
+github.com/go-ldap/ldap/v3 v3.4.12 h1:1b81mv7MagXZ7+1r7cLTWmyuTqVqdwbtJSjC0DAp9s4=
+github.com/go-ldap/ldap/v3 v3.4.12/go.mod h1:+SPAGcTtOfmGsCb3h1RFiq4xpp4N636G75OEace8lNo=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
-github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
-github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
+github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
+github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
@@ -178,8 +197,8 @@ github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZs
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
-github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
-github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
+github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
+github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
@@ -195,11 +214,6 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
-github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
-github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
-github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
-github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
-github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
@@ -210,9 +224,7 @@ github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:x
github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
-github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
-github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
@@ -220,12 +232,9 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
-github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
@@ -240,23 +249,24 @@ github.com/google/nftables v0.3.0 h1:bkyZ0cbpVeMHXOrtlFc8ISmfVqq5gPJukoYieyVmITg
github.com/google/nftables v0.3.0/go.mod h1:BCp9FsrbF1Fn/Yu6CLUc9GGZFw/+hsxfluNXXmxBfRM=
github.com/google/pprof v0.0.0-20211214055906-6f57359322fd h1:1FjCyPC+syAzJ5/2S8fqdZK1R22vvA0J7JZKcuOIQ7Y=
github.com/google/pprof v0.0.0-20211214055906-6f57359322fd/go.mod h1:KgnwoLYCZ8IQu3XUZ8Nc/bM9CCZFOyjUNOSygVozoDg=
-github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o=
-github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw=
-github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
+github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
-github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs=
-github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0=
-github.com/googleapis/gax-go/v2 v2.12.3 h1:5/zPPDvw8Q1SuXjrqrZslrqT7dL/uJT2CQii/cLCKqA=
-github.com/googleapis/gax-go/v2 v2.12.3/go.mod h1:AKloxT6GtNbaLm8QTNSidHUVsHYcBHwWRvkNFJUQcS4=
+github.com/googleapis/enterprise-certificate-proxy v0.3.7 h1:zrn2Ee/nWmHulBx5sAVrGgAa0f2/R35S4DJwfFaUPFQ=
+github.com/googleapis/enterprise-certificate-proxy v0.3.7/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA=
+github.com/googleapis/gax-go/v2 v2.15.0 h1:SyjDc1mGgZU5LncH8gimWo9lW1DtIfPibOG81vgd/bo=
+github.com/googleapis/gax-go/v2 v2.15.0/go.mod h1:zVVkkxAQHa1RQpg9z2AUCMnKhi0Qld9rcmyfL1OZhoc=
github.com/gopacket/gopacket v1.1.1 h1:zbx9F9d6A7sWNkFKrvMBZTfGgxFoY4NgUudFVVHMfcw=
github.com/gopacket/gopacket v1.1.1/go.mod h1:HavMeONEl7W9036of9LbSWoonqhH7HA1+ZRO+rMIvFs=
-github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
-github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
+github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE=
+github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w=
+github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
+github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 h1:Fkzd8ktnpOR9h47SXHe2AYPwelXLH2GjGsjlAloiWfo=
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357/go.mod h1:w9Y7gY31krpLmrVU5ZPG9H7l9fZuRu5/3R3S3FMtVQ4=
-github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 h1:YBftPWNWd4WwGqtY2yeZL2ef8rHAxPBD8KFhJpmcqms=
-github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0/go.mod h1:YN5jB8ie0yfIUg6VvR9Kz84aCaG7AsGZnLjhHbUqwPg=
+github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 h1:5ZPtiqj0JL5oKWmcsq4VMaAW5ukBEgSGXEN89zeH1Jo=
+github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3/go.mod h1:ndYquD05frm2vACXE1nsccT4oJzjhw2arTS2cpUD1PI=
github.com/hack-pad/go-indexeddb v0.3.2 h1:DTqeJJYc1usa45Q5r52t01KhvlSN02+Oq+tQbSBI91A=
github.com/hack-pad/go-indexeddb v0.3.2/go.mod h1:QvfTevpDVlkfomY498LhstjwbPW6QC4VC/lxYb0Kom0=
github.com/hack-pad/safejs v0.1.0 h1:qPS6vjreAqh2amUqj4WNG1zIw7qlRQJ9K10eDKMCnE8=
@@ -274,6 +284,8 @@ github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/b
github.com/hashicorp/go-version v1.6.0 h1:feTTfFNnjP967rlCxM/I9g701jU+RN74YKx2mOkIeek=
github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
+github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
+github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
@@ -285,6 +297,18 @@ github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
+github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8=
+github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs=
+github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo=
+github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM=
+github.com/jcmturner/gofork v1.7.6 h1:QH0l3hzAU1tfT3rZCnW5zXl+orbkNMMRGJfdJjHVETg=
+github.com/jcmturner/gofork v1.7.6/go.mod h1:1622LH6i/EZqLloHfE7IeZ0uEJwMSUyQ/nDd82IeqRo=
+github.com/jcmturner/goidentity/v6 v6.0.1 h1:VKnZd2oEIMorCTsFBnJWbExfNN7yZr3EhJAxwOkZg6o=
+github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg=
+github.com/jcmturner/gokrb5/v8 v8.4.4 h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh687T8=
+github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs=
+github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY=
+github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc=
github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade h1:FmusiCI1wHw+XQbvL9M+1r/C3SPqKrmBaIOYwVfQoDE=
github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade/go.mod h1:ZDXo8KHryOWSIqnsb/CiDq7hQUYryCgdVnxbj8tDG7o=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
@@ -295,6 +319,8 @@ github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9Y
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
+github.com/jonboulle/clockwork v0.5.0 h1:Hyh9A8u51kptdkR+cqRpT1EebBwTn1oK9YfGYbdFz6I=
+github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7XN3SzBPjZF60=
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 h1:YLvr1eE6cdCqjOe972w/cYF+FjW34v27+9Vo5106B4M=
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw=
@@ -309,8 +335,11 @@ github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuV
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
+github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
+github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
+github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
@@ -329,9 +358,11 @@ github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tA
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k=
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
+github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU=
+github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To=
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
-github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
-github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
+github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
+github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg=
@@ -344,8 +375,12 @@ github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs=
github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
+github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw=
+github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s=
github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE=
+github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ=
+github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk=
@@ -364,6 +399,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
+github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U=
+github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
@@ -434,6 +471,7 @@ github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8=
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=
github.com/pion/turn/v4 v4.1.1/go.mod h1:2123tHk1O++vmjI5VSD0awT50NywDAq5A2NNNU4Jjs8=
+github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/profile v1.7.0 h1:hnbDkaNWPCLMO9wGLdBFTIZvzDrDfBM2072E1S9gJkA=
@@ -445,25 +483,26 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
-github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q=
-github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0=
-github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
-github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
-github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
-github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
-github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
-github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
-github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
+github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
+github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
+github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
+github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
+github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
+github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
+github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
+github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
github.com/quic-go/quic-go v0.49.1 h1:e5JXpUyF0f2uFjckQzD8jTghZrOUK1xxDqqZhlwixo0=
github.com/quic-go/quic-go v0.49.1/go.mod h1:s2wDnmCdooUQBmQfpUSTCYBl1/D4FcqbULMMkASvR6s=
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
-github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
-github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
+github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
+github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/rs/cors v1.8.0 h1:P2KMzcFwrPoSjkF1WLRPsp3UMLyql8L4v9hQpVeK5so=
github.com/rs/cors v1.8.0/go.mod h1:EBwu+T5AvHOcXwvZIkQFjUN6s8Czyqw12GL/Y0tUyRM=
github.com/rs/xid v1.3.0 h1:6NjYksEUlhurdVehpc7S7dk6DAmcKv8V9gG0FsVN2U4=
github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
+github.com/russellhaering/goxmldsig v1.5.0 h1:AU2UkkYIUOTyZRbe08XMThaOCelArgvNfYapcmSjBNw=
+github.com/russellhaering/goxmldsig v1.5.0/go.mod h1:x98CjQNFJcWfMxeOrMnMKg70lvDP6tE0nTaeUnjXDmk=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/rymdport/portal v0.4.2 h1:7jKRSemwlTyVHHrTGgQg7gmNPJs88xkbKcIL3NlcmSU=
github.com/rymdport/portal v0.4.2/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4=
@@ -473,21 +512,26 @@ github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFt
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
+github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
+github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA=
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog=
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8=
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
-github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
-github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
-github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
-github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
+github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w=
+github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
+github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s=
+github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0=
+github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY=
+github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c h1:km8GpoQut05eY3GiYWEedbTT0qnSxrCjsVbb7yKY1KE=
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c/go.mod h1:cNQ3dwVJtS5Hmnjxy6AgTPd0Inb3pW05ftPSX7NZO7Q=
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef h1:Ch6Q+AZUxDBCVqdkI8FSpFyZDtCVBc2VmejdNrm5rRQ=
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef/go.mod h1:nXTWP6+gD5+LUJ8krVhhoeHjvHTutPxMYl5SvkcnJNE=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
@@ -499,7 +543,6 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
-github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
@@ -553,30 +596,28 @@ github.com/zeebo/blake3 v0.2.3 h1:TFoLXsjeXqRNFxSbk35Dk4YtszE/MQQGK10BH4ptoTg=
github.com/zeebo/blake3 v0.2.3/go.mod h1:mjJjZpnsyIVtVgTOSpJ9vmRE4wgDeyt2HU3qXvvKCaQ=
github.com/zeebo/pcg v1.0.1 h1:lyqfGeWiv4ahac6ttHs+I5hwtH/+1mrhlCtVNQM2kHo=
github.com/zeebo/pcg v1.0.1/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4=
-go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
-go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
-go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
-go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
-go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 h1:4Pp6oUg3+e/6M4C0A/3kJ2VYa++dsWVTtGgLVj5xtHg=
-go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0=
-go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI=
-go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0/go.mod h1:vy+2G/6NvVMpwGX/NyLqcC41fxepnuKHk16E6IZUcJc=
-go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ=
-go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y=
+go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
+go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
+go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 h1:q4XOmH/0opmeuJtPsbFNivyl7bCt7yRBbeEm2sC/XtQ=
+go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0/go.mod h1:snMWehoOh2wsEwnvvwtDyFCxVeDAODenXHtn5vzrKjo=
+go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 h1:F7Jx+6hwnZ41NSFTO5q4LYDtJRXBf2PD0rNBkeB/lus=
+go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0/go.mod h1:UHB22Z8QsdRDrnAtX4PntOl36ajSxcdUMt1sF7Y6E7Q=
+go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
+go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU=
go.opentelemetry.io/otel/exporters/prometheus v0.48.0 h1:sBQe3VNGUjY9IKWQC6z2lNqa5iGbDSxhs60ABwK4y0s=
go.opentelemetry.io/otel/exporters/prometheus v0.48.0/go.mod h1:DtrbMzoZWwQHyrQmCfLam5DZbnmorsGbOtTbYHycU5o=
-go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M=
-go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE=
-go.opentelemetry.io/otel/sdk v1.35.0 h1:iPctf8iprVySXSKJffSS79eOjl9pvxV9ZqOWT0QejKY=
-go.opentelemetry.io/otel/sdk v1.35.0/go.mod h1:+ga1bZliga3DxJ3CQGg3updiaAJoNECOgJREo9KHGQg=
-go.opentelemetry.io/otel/sdk/metric v1.35.0 h1:1RriWBmCKgkeHEhM7a2uMjMUfP7MsOF5JpUCaEqEI9o=
-go.opentelemetry.io/otel/sdk/metric v1.35.0/go.mod h1:is6XYCUMpcKi+ZsOvfluY5YstFnhW0BidkR+gL+qN+w=
-go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs=
-go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc=
+go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
+go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
+go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
+go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg=
+go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM=
+go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA=
+go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
+go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
@@ -587,6 +628,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
+go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
+go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
goauthentik.io/api/v3 v3.2023051.3 h1:NebAhD/TeTWNo/9X3/Uj+rM5fG1HaiLOlKTNLQv9Qq4=
goauthentik.io/api/v3 v3.2023051.3/go.mod h1:nYECml4jGbp/541hj8GcylKQG1gVBsKppHy4+7G8u4U=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
@@ -600,16 +643,12 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
-golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
-golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
-golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
+golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
+golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
golang.org/x/image v0.33.0 h1:LXRZRnv1+zGd5XBUVRFmYEphyyKJjQjCRiOuAP3sZfQ=
golang.org/x/image v0.33.0/go.mod h1:DD3OsTYT9chzuzTQt+zMcOlBHgfoKQb1gry8p76Y1sc=
-golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
-golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
-golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mobile v0.0.0-20251113184115-a159579294ab h1:Iqyc+2zr7aGyLuEadIm0KRJP0Wwt+fhlXLa51Fxf1+Q=
golang.org/x/mobile v0.0.0-20251113184115-a159579294ab/go.mod h1:Eq3Nh/5pFSWug2ohiudJ1iyU59SO78QFuh4qTTN++I0=
@@ -624,18 +663,13 @@ golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
-golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
-golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
-golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
-golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
-golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
@@ -649,12 +683,10 @@ golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
-golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
-golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
-golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
+golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw=
+golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -665,9 +697,8 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
-golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
-golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
-golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
+golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -703,8 +734,8 @@ golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
-golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
-golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
+golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
+golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@@ -717,8 +748,8 @@ golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
-golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
-golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
+golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
+golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -730,15 +761,11 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
-golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
-golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
-golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
-golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
+golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
+golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
+golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
+golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
-golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
-golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
-golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
-golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
@@ -761,42 +788,33 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvY
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
-google.golang.org/api v0.177.0 h1:8a0p/BbPa65GlqGWtUKxot4p0TV8OGOfyTjtmkXNXmk=
-google.golang.org/api v0.177.0/go.mod h1:srbhue4MLjkjbkux5p3dw/ocYOSZTaIEvf7bCOnFQDw=
-google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
-google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
+gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
+gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
+google.golang.org/api v0.257.0 h1:8Y0lzvHlZps53PEaw+G29SsQIkuKrumGWs9puiexNAA=
+google.golang.org/api v0.257.0/go.mod h1:4eJrr+vbVaZSqs7vovFd1Jb/A6ml6iw2e6FBYf3GAO4=
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
-google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
-google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
-google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
-google.golang.org/genproto v0.0.0-20240123012728-ef4313101c80 h1:KAeGQVN3M9nD0/bQXnr/ClcEMJ968gUXJQ9pwfSynuQ=
-google.golang.org/genproto/googleapis/api v0.0.0-20250324211829-b45e905df463 h1:hE3bRWtU6uceqlh4fhrSnUyjKHMKB9KrTLLG+bc0ddM=
-google.golang.org/genproto/googleapis/api v0.0.0-20250324211829-b45e905df463/go.mod h1:U90ffi8eUL9MwPcrJylN5+Mk2v3vuPDptd5yyNUiRR8=
-google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 h1:pFyd6EwwL2TqFf8emdthzeX+gZE1ElRq3iM8pui4KBY=
-google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
-google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
-google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
-google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
-google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
-google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc=
-google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok=
-google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc=
+google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4=
+google.golang.org/genproto v0.0.0-20250603155806-513f23925822/go.mod h1:HubltRL7rMh0LfnQPkMH4NPDFEWp0jw3vixw7jEM53s=
+google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8 h1:mepRgnBZa07I4TRuomDE4sTIYieg/osKmzIf4USdWS4=
+google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8/go.mod h1:fDMmzKV90WSg1NbozdqrE64fkuTv6mlq2zxo9ad+3yo=
+google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846 h1:Wgl1rcDNThT+Zn47YyCXOXyX/COgMTIdhJ717F0l4xk=
+google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
+google.golang.org/grpc v1.77.0 h1:wVVY6/8cGA6vvffn+wWK5ToddbgdU3d8MNENr4evgXM=
+google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHhrNig=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
-google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
-google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
-google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
-google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
-google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
+google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
+google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
@@ -832,5 +850,3 @@ gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY=
gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 h1:qDCwdCWECGnwQSQC01Dpnp09fRHxJs9PbktotUqG+hs=
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1/go.mod h1:8hmigyCdYtw5xJGfQDJzSH5Ju8XEIDBnpyi8+O6GRt8=
-honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
-honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
diff --git a/idp/dex/config.go b/idp/dex/config.go
new file mode 100644
index 000000000..57f832406
--- /dev/null
+++ b/idp/dex/config.go
@@ -0,0 +1,301 @@
+package dex
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "log/slog"
+ "os"
+ "time"
+
+ "golang.org/x/crypto/bcrypt"
+ "gopkg.in/yaml.v3"
+
+ "github.com/dexidp/dex/server"
+ "github.com/dexidp/dex/storage"
+ "github.com/dexidp/dex/storage/sql"
+
+ "github.com/netbirdio/netbird/idp/dex/web"
+)
+
+// parseDuration parses a duration string (e.g., "6h", "24h", "168h").
+func parseDuration(s string) (time.Duration, error) {
+ return time.ParseDuration(s)
+}
+
// YAMLConfig represents the YAML configuration file format (mirrors dex's config format).
// It is decoded from disk by LoadConfig and converted into a dex
// server.Config by ToServerConfig.
type YAMLConfig struct {
	// Issuer is the OIDC issuer URL advertised by the server.
	Issuer   string   `yaml:"issuer" json:"issuer"`
	Storage  Storage  `yaml:"storage" json:"storage"`
	Web      Web      `yaml:"web" json:"web"`
	GRPC     GRPC     `yaml:"grpc" json:"grpc"`
	OAuth2   OAuth2   `yaml:"oauth2" json:"oauth2"`
	Expiry   Expiry   `yaml:"expiry" json:"expiry"`
	Logger   Logger   `yaml:"logger" json:"logger"`
	Frontend Frontend `yaml:"frontend" json:"frontend"`

	// StaticConnectors are user defined connectors specified in the config file
	StaticConnectors []Connector `yaml:"connectors" json:"connectors"`

	// StaticClients cause the server to use this list of clients rather than
	// querying the storage. Write operations, like creating a client, will fail.
	StaticClients []storage.Client `yaml:"staticClients" json:"staticClients"`

	// If enabled, the server will maintain a list of passwords which can be used
	// to identify a user.
	EnablePasswordDB bool `yaml:"enablePasswordDB" json:"enablePasswordDB"`

	// StaticPasswords cause the server use this list of passwords rather than
	// querying the storage.
	StaticPasswords []Password `yaml:"staticPasswords" json:"staticPasswords"`
}

// Web is the config format for the HTTP server.
type Web struct {
	// HTTP and HTTPS are listen addresses (e.g. ":5556"); Validate requires
	// at least one of them to be set.
	HTTP           string   `yaml:"http" json:"http"`
	HTTPS          string   `yaml:"https" json:"https"`
	AllowedOrigins []string `yaml:"allowedOrigins" json:"allowedOrigins"`
	AllowedHeaders []string `yaml:"allowedHeaders" json:"allowedHeaders"`
}

// GRPC is the config for the gRPC API.
type GRPC struct {
	// Addr is the gRPC listen address; an empty value disables the API.
	Addr        string `yaml:"addr" json:"addr"`
	TLSCert     string `yaml:"tlsCert" json:"tlsCert"`
	TLSKey      string `yaml:"tlsKey" json:"tlsKey"`
	TLSClientCA string `yaml:"tlsClientCA" json:"tlsClientCA"`
}

// OAuth2 describes enabled OAuth2 extensions.
type OAuth2 struct {
	SkipApprovalScreen    bool     `yaml:"skipApprovalScreen" json:"skipApprovalScreen"`
	AlwaysShowLoginScreen bool     `yaml:"alwaysShowLoginScreen" json:"alwaysShowLoginScreen"`
	PasswordConnector     string   `yaml:"passwordConnector" json:"passwordConnector"`
	ResponseTypes         []string `yaml:"responseTypes" json:"responseTypes"`
	GrantTypes            []string `yaml:"grantTypes" json:"grantTypes"`
}

// Expiry holds configuration for the validity period of components.
// All fields are Go duration strings (e.g. "6h", "24h") parsed by
// parseDuration in ToServerConfig.
type Expiry struct {
	SigningKeys    string              `yaml:"signingKeys" json:"signingKeys"`
	IDTokens       string              `yaml:"idTokens" json:"idTokens"`
	AuthRequests   string              `yaml:"authRequests" json:"authRequests"`
	DeviceRequests string              `yaml:"deviceRequests" json:"deviceRequests"`
	RefreshTokens  RefreshTokensExpiry `yaml:"refreshTokens" json:"refreshTokens"`
}

// RefreshTokensExpiry holds configuration for refresh token expiry.
// The string fields are passed verbatim to server.NewRefreshTokenPolicy
// in GetRefreshTokenPolicy.
type RefreshTokensExpiry struct {
	ReuseInterval     string `yaml:"reuseInterval" json:"reuseInterval"`
	ValidIfNotUsedFor string `yaml:"validIfNotUsedFor" json:"validIfNotUsedFor"`
	AbsoluteLifetime  string `yaml:"absoluteLifetime" json:"absoluteLifetime"`
	DisableRotation   bool   `yaml:"disableRotation" json:"disableRotation"`
}

// Logger holds configuration required to customize logging.
// NOTE(review): these fields are parsed but not applied anywhere in the
// visible code — confirm whether they are consumed elsewhere.
type Logger struct {
	Level  string `yaml:"level" json:"level"`
	Format string `yaml:"format" json:"format"`
}

// Frontend holds the server's frontend templates and assets config.
// When Dir is empty, ToServerConfig falls back to the embedded
// NetBird-styled templates.
type Frontend struct {
	Dir     string            `yaml:"dir" json:"dir"`
	Theme   string            `yaml:"theme" json:"theme"`
	Issuer  string            `yaml:"issuer" json:"issuer"`
	LogoURL string            `yaml:"logoURL" json:"logoURL"`
	Extra   map[string]string `yaml:"extra" json:"extra"`
}

// Storage holds app's storage configuration.
// Only the "sqlite3" type is supported by OpenStorage.
type Storage struct {
	Type   string                 `yaml:"type" json:"type"`
	Config map[string]interface{} `yaml:"config" json:"config"`
}

// Password represents a static user configuration.
// It aliases storage.Password so it can gain a custom YAML unmarshaler
// (see UnmarshalYAML) without changing the stored representation.
type Password storage.Password
+
// UnmarshalYAML decodes a static password entry. The hash may be given
// inline ("hash"), or indirectly via an environment variable name
// ("hashFromEnv"). The value must be a bcrypt hash, either raw or —
// for backwards compatibility — base64-encoded.
func (p *Password) UnmarshalYAML(node *yaml.Node) error {
	var data struct {
		Email       string `yaml:"email"`
		Username    string `yaml:"username"`
		UserID      string `yaml:"userID"`
		Hash        string `yaml:"hash"`
		HashFromEnv string `yaml:"hashFromEnv"`
	}
	if err := node.Decode(&data); err != nil {
		return err
	}
	// Populate identity fields first; the hash is validated below.
	*p = Password(storage.Password{
		Email:    data.Email,
		Username: data.Username,
		UserID:   data.UserID,
	})
	// Inline hash takes precedence; fall back to the environment variable.
	if len(data.Hash) == 0 && len(data.HashFromEnv) > 0 {
		data.Hash = os.Getenv(data.HashFromEnv)
	}
	if len(data.Hash) == 0 {
		return fmt.Errorf("no password hash provided for user %s", data.Email)
	}

	// If this value is a valid bcrypt, use it.
	_, bcryptErr := bcrypt.Cost([]byte(data.Hash))
	if bcryptErr == nil {
		p.Hash = []byte(data.Hash)
		return nil
	}

	// For backwards compatibility try to base64 decode this value.
	// If base64 decoding fails, report the original bcrypt error, since
	// the raw value was most likely intended as a bcrypt hash.
	hashBytes, err := base64.StdEncoding.DecodeString(data.Hash)
	if err != nil {
		return fmt.Errorf("malformed bcrypt hash: %v", bcryptErr)
	}
	if _, err := bcrypt.Cost(hashBytes); err != nil {
		return fmt.Errorf("malformed bcrypt hash: %v", err)
	}
	p.Hash = hashBytes
	return nil
}
+
// Connector is a connector configuration that can unmarshal YAML dynamically.
// Config holds the connector-type-specific settings as a free-form map.
type Connector struct {
	Type   string                 `yaml:"type" json:"type"`
	Name   string                 `yaml:"name" json:"name"`
	ID     string                 `yaml:"id" json:"id"`
	Config map[string]interface{} `yaml:"config" json:"config"`
}

// ToStorageConnector converts a Connector to storage.Connector type.
// The free-form Config map is re-serialized to JSON, which is the raw
// form storage.Connector carries.
func (c *Connector) ToStorageConnector() (storage.Connector, error) {
	data, err := json.Marshal(c.Config)
	if err != nil {
		return storage.Connector{}, fmt.Errorf("failed to marshal connector config: %v", err)
	}

	return storage.Connector{
		ID:     c.ID,
		Type:   c.Type,
		Name:   c.Name,
		Config: data,
	}, nil
}
+
// StorageConfig is a configuration that can create a storage.
// NOTE(review): this interface is not referenced in the visible code —
// confirm it is used elsewhere or remove it.
type StorageConfig interface {
	Open(logger *slog.Logger) (storage.Storage, error)
}

// OpenStorage opens a storage based on the config.
// Only the "sqlite3" backend is supported; it requires a "file" entry
// in the config map pointing at the database path.
func (s *Storage) OpenStorage(logger *slog.Logger) (storage.Storage, error) {
	switch s.Type {
	case "sqlite3":
		// Non-string or missing "file" values both yield "" here and are
		// rejected below.
		file, _ := s.Config["file"].(string)
		if file == "" {
			return nil, fmt.Errorf("sqlite3 storage requires 'file' config")
		}
		return (&sql.SQLite3{File: file}).Open(logger)
	default:
		return nil, fmt.Errorf("unsupported storage type: %s", s.Type)
	}
}
+
+// Validate validates the configuration
+func (c *YAMLConfig) Validate() error {
+ if c.Issuer == "" {
+ return fmt.Errorf("no issuer specified in config file")
+ }
+ if c.Storage.Type == "" {
+ return fmt.Errorf("no storage type specified in config file")
+ }
+ if c.Web.HTTP == "" && c.Web.HTTPS == "" {
+ return fmt.Errorf("must supply a HTTP/HTTPS address to listen on")
+ }
+ if !c.EnablePasswordDB && len(c.StaticPasswords) != 0 {
+ return fmt.Errorf("cannot specify static passwords without enabling password db")
+ }
+ return nil
+}
+
+// ToServerConfig converts YAMLConfig to dex server.Config
+func (c *YAMLConfig) ToServerConfig(stor storage.Storage, logger *slog.Logger) server.Config {
+ cfg := server.Config{
+ Issuer: c.Issuer,
+ Storage: stor,
+ Logger: logger,
+ SkipApprovalScreen: c.OAuth2.SkipApprovalScreen,
+ AllowedOrigins: c.Web.AllowedOrigins,
+ AllowedHeaders: c.Web.AllowedHeaders,
+ Web: server.WebConfig{
+ Issuer: c.Frontend.Issuer,
+ LogoURL: c.Frontend.LogoURL,
+ Theme: c.Frontend.Theme,
+ Dir: c.Frontend.Dir,
+ Extra: c.Frontend.Extra,
+ },
+ }
+
+ // Use embedded NetBird-styled templates if no custom dir specified
+ if c.Frontend.Dir == "" {
+ cfg.Web.WebFS = web.FS()
+ }
+
+ if len(c.OAuth2.ResponseTypes) > 0 {
+ cfg.SupportedResponseTypes = c.OAuth2.ResponseTypes
+ }
+
+ // Apply expiry settings
+ if c.Expiry.SigningKeys != "" {
+ if d, err := parseDuration(c.Expiry.SigningKeys); err == nil {
+ cfg.RotateKeysAfter = d
+ }
+ }
+ if c.Expiry.IDTokens != "" {
+ if d, err := parseDuration(c.Expiry.IDTokens); err == nil {
+ cfg.IDTokensValidFor = d
+ }
+ }
+ if c.Expiry.AuthRequests != "" {
+ if d, err := parseDuration(c.Expiry.AuthRequests); err == nil {
+ cfg.AuthRequestsValidFor = d
+ }
+ }
+ if c.Expiry.DeviceRequests != "" {
+ if d, err := parseDuration(c.Expiry.DeviceRequests); err == nil {
+ cfg.DeviceRequestsValidFor = d
+ }
+ }
+
+ return cfg
+}
+
// GetRefreshTokenPolicy creates a RefreshTokenPolicy from the expiry config.
// This should be called after ToServerConfig and the policy set on the config.
// The duration strings from RefreshTokensExpiry are passed through
// verbatim; parsing/validation is left to server.NewRefreshTokenPolicy.
func (c *YAMLConfig) GetRefreshTokenPolicy(logger *slog.Logger) (*server.RefreshTokenPolicy, error) {
	return server.NewRefreshTokenPolicy(
		logger,
		c.Expiry.RefreshTokens.DisableRotation,
		c.Expiry.RefreshTokens.ValidIfNotUsedFor,
		c.Expiry.RefreshTokens.AbsoluteLifetime,
		c.Expiry.RefreshTokens.ReuseInterval,
	)
}
+
+// LoadConfig loads configuration from a YAML file
+func LoadConfig(path string) (*YAMLConfig, error) {
+ data, err := os.ReadFile(path)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read config file: %w", err)
+ }
+
+ var cfg YAMLConfig
+ if err := yaml.Unmarshal(data, &cfg); err != nil {
+ return nil, fmt.Errorf("failed to parse config file: %w", err)
+ }
+
+ if err := cfg.Validate(); err != nil {
+ return nil, err
+ }
+
+ return &cfg, nil
+}
diff --git a/idp/dex/provider.go b/idp/dex/provider.go
new file mode 100644
index 000000000..09713a226
--- /dev/null
+++ b/idp/dex/provider.go
@@ -0,0 +1,934 @@
+// Package dex provides an embedded Dex OIDC identity provider.
+package dex
+
+import (
+ "context"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "log/slog"
+ "net"
+ "net/http"
+ "os"
+ "path/filepath"
+ "strings"
+ "sync"
+ "time"
+
+ dexapi "github.com/dexidp/dex/api/v2"
+ "github.com/dexidp/dex/server"
+ "github.com/dexidp/dex/storage"
+ "github.com/dexidp/dex/storage/sql"
+ "github.com/google/uuid"
+ "github.com/prometheus/client_golang/prometheus"
+ "golang.org/x/crypto/bcrypt"
+ "google.golang.org/grpc"
+)
+
// Config matches what management/internals/server/server.go expects.
// It is the programmatic (non-YAML) way to configure the provider via
// NewProvider.
type Config struct {
	// Issuer is the externally visible OIDC issuer URL; NewProvider
	// appends "/oauth2" if it is not already suffixed.
	Issuer  string
	// Port is the HTTP listen port used when no YAML config is present.
	Port    int
	// DataDir is where the SQLite database (oidc.db) is created.
	DataDir string
	// DevMode is carried in the config but not read in the visible code.
	DevMode bool

	// GRPCAddr is the address for the gRPC API (e.g., ":5557"). Empty disables gRPC.
	GRPCAddr string
}

// Provider wraps a Dex server together with its storage, HTTP listener,
// and optional gRPC API server. All lifecycle state is guarded by mu.
type Provider struct {
	config       *Config
	yamlConfig   *YAMLConfig
	dexServer    *server.Server
	httpServer   *http.Server
	listener     net.Listener
	grpcServer   *grpc.Server
	grpcListener net.Listener
	storage      storage.Storage
	logger       *slog.Logger
	mu           sync.Mutex
	// running tracks whether Start has succeeded and Stop not yet run.
	running      bool
}
+
// NewProvider creates and initializes the Dex server from a programmatic
// Config: it opens (creating if needed) a SQLite database under
// config.DataDir, guarantees a "local" password connector exists, and
// builds a dex server with fixed defaults (code flow only, approval
// screen skipped, 6h key rotation, 24h ID tokens).
//
// On any failure after the storage is opened, the storage is closed
// before returning the error.
func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
	if config.Issuer == "" {
		return nil, fmt.Errorf("issuer is required")
	}
	if config.Port <= 0 {
		return nil, fmt.Errorf("invalid port")
	}
	if config.DataDir == "" {
		return nil, fmt.Errorf("data directory is required")
	}

	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))

	// Ensure data directory exists (0700: database may contain secrets)
	if err := os.MkdirAll(config.DataDir, 0700); err != nil {
		return nil, fmt.Errorf("failed to create data directory: %w", err)
	}

	// Initialize SQLite storage
	dbPath := filepath.Join(config.DataDir, "oidc.db")
	sqliteConfig := &sql.SQLite3{File: dbPath}
	stor, err := sqliteConfig.Open(logger)
	if err != nil {
		return nil, fmt.Errorf("failed to open storage: %w", err)
	}

	// Ensure a local connector exists (for password authentication)
	if err := ensureLocalConnector(ctx, stor); err != nil {
		stor.Close()
		return nil, fmt.Errorf("failed to ensure local connector: %w", err)
	}

	// Ensure issuer ends with /oauth2 for proper path mounting
	issuer := strings.TrimSuffix(config.Issuer, "/")
	if !strings.HasSuffix(issuer, "/oauth2") {
		issuer += "/oauth2"
	}

	// Build refresh token policy (required to avoid nil pointer panics)
	refreshPolicy, err := server.NewRefreshTokenPolicy(logger, false, "", "", "")
	if err != nil {
		stor.Close()
		return nil, fmt.Errorf("failed to create refresh token policy: %w", err)
	}

	// Build Dex server config - use Dex's types directly
	dexConfig := server.Config{
		Issuer:                 issuer,
		Storage:                stor,
		SkipApprovalScreen:     true,
		SupportedResponseTypes: []string{"code"},
		Logger:                 logger,
		PrometheusRegistry:     prometheus.NewRegistry(),
		RotateKeysAfter:        6 * time.Hour,
		IDTokensValidFor:       24 * time.Hour,
		RefreshTokenPolicy:     refreshPolicy,
		Web: server.WebConfig{
			Issuer: "NetBird",
		},
	}

	dexSrv, err := server.NewServer(ctx, dexConfig)
	if err != nil {
		stor.Close()
		return nil, fmt.Errorf("failed to create dex server: %w", err)
	}

	return &Provider{
		config:    config,
		dexServer: dexSrv,
		storage:   stor,
		logger:    logger,
	}, nil
}
+
// NewProviderFromYAML creates and initializes the Dex server from a
// YAMLConfig: it opens the configured storage backend, seeds it with
// static connectors/passwords/clients (initializeStorage), and builds
// the dex server config with defaults applied (buildDexConfig).
//
// Note: unlike NewProvider, the issuer is used as-is — no "/oauth2"
// suffix is appended here.
func NewProviderFromYAML(ctx context.Context, yamlConfig *YAMLConfig) (*Provider, error) {
	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))

	stor, err := yamlConfig.Storage.OpenStorage(logger)
	if err != nil {
		return nil, fmt.Errorf("failed to open storage: %w", err)
	}

	if err := initializeStorage(ctx, stor, yamlConfig); err != nil {
		stor.Close()
		return nil, err
	}

	dexConfig := buildDexConfig(yamlConfig, stor, logger)
	dexConfig.RefreshTokenPolicy, err = yamlConfig.GetRefreshTokenPolicy(logger)
	if err != nil {
		stor.Close()
		return nil, fmt.Errorf("failed to create refresh token policy: %w", err)
	}

	dexSrv, err := server.NewServer(ctx, dexConfig)
	if err != nil {
		stor.Close()
		return nil, fmt.Errorf("failed to create dex server: %w", err)
	}

	// A partial Config is kept so Start/startGRPCServer can read GRPCAddr.
	return &Provider{
		config:     &Config{Issuer: yamlConfig.Issuer, GRPCAddr: yamlConfig.GRPC.Addr},
		yamlConfig: yamlConfig,
		dexServer:  dexSrv,
		storage:    stor,
		logger:     logger,
	}, nil
}
+
+// initializeStorage sets up connectors, passwords, and clients in storage
+func initializeStorage(ctx context.Context, stor storage.Storage, cfg *YAMLConfig) error {
+ if cfg.EnablePasswordDB {
+ if err := ensureLocalConnector(ctx, stor); err != nil {
+ return fmt.Errorf("failed to ensure local connector: %w", err)
+ }
+ }
+ if err := ensureStaticPasswords(ctx, stor, cfg.StaticPasswords); err != nil {
+ return err
+ }
+ if err := ensureStaticClients(ctx, stor, cfg.StaticClients); err != nil {
+ return err
+ }
+ return ensureStaticConnectors(ctx, stor, cfg.StaticConnectors)
+}
+
+// ensureStaticPasswords creates or updates static passwords in storage
+func ensureStaticPasswords(ctx context.Context, stor storage.Storage, passwords []Password) error {
+ for _, pw := range passwords {
+ existing, err := stor.GetPassword(ctx, pw.Email)
+ if errors.Is(err, storage.ErrNotFound) {
+ if err := stor.CreatePassword(ctx, storage.Password(pw)); err != nil {
+ return fmt.Errorf("failed to create password for %s: %w", pw.Email, err)
+ }
+ continue
+ }
+ if err != nil {
+ return fmt.Errorf("failed to get password for %s: %w", pw.Email, err)
+ }
+ if string(existing.Hash) != string(pw.Hash) {
+ if err := stor.UpdatePassword(ctx, pw.Email, func(old storage.Password) (storage.Password, error) {
+ old.Hash = pw.Hash
+ old.Username = pw.Username
+ return old, nil
+ }); err != nil {
+ return fmt.Errorf("failed to update password for %s: %w", pw.Email, err)
+ }
+ }
+ }
+ return nil
+}
+
+// ensureStaticClients creates or updates static clients in storage
+func ensureStaticClients(ctx context.Context, stor storage.Storage, clients []storage.Client) error {
+ for _, client := range clients {
+ _, err := stor.GetClient(ctx, client.ID)
+ if errors.Is(err, storage.ErrNotFound) {
+ if err := stor.CreateClient(ctx, client); err != nil {
+ return fmt.Errorf("failed to create client %s: %w", client.ID, err)
+ }
+ continue
+ }
+ if err != nil {
+ return fmt.Errorf("failed to get client %s: %w", client.ID, err)
+ }
+ if err := stor.UpdateClient(ctx, client.ID, func(old storage.Client) (storage.Client, error) {
+ old.RedirectURIs = client.RedirectURIs
+ old.Name = client.Name
+ old.Public = client.Public
+ return old, nil
+ }); err != nil {
+ return fmt.Errorf("failed to update client %s: %w", client.ID, err)
+ }
+ }
+ return nil
+}
+
+// ensureStaticConnectors creates or updates static connectors in storage
+func ensureStaticConnectors(ctx context.Context, stor storage.Storage, connectors []Connector) error {
+ for _, conn := range connectors {
+ storConn, err := conn.ToStorageConnector()
+ if err != nil {
+ return fmt.Errorf("failed to convert connector %s: %w", conn.ID, err)
+ }
+ _, err = stor.GetConnector(ctx, conn.ID)
+ if errors.Is(err, storage.ErrNotFound) {
+ if err := stor.CreateConnector(ctx, storConn); err != nil {
+ return fmt.Errorf("failed to create connector %s: %w", conn.ID, err)
+ }
+ continue
+ }
+ if err != nil {
+ return fmt.Errorf("failed to get connector %s: %w", conn.ID, err)
+ }
+ if err := stor.UpdateConnector(ctx, conn.ID, func(old storage.Connector) (storage.Connector, error) {
+ old.Name = storConn.Name
+ old.Config = storConn.Config
+ return old, nil
+ }); err != nil {
+ return fmt.Errorf("failed to update connector %s: %w", conn.ID, err)
+ }
+ }
+ return nil
+}
+
+// buildDexConfig creates a server.Config with defaults applied
+func buildDexConfig(yamlConfig *YAMLConfig, stor storage.Storage, logger *slog.Logger) server.Config {
+ cfg := yamlConfig.ToServerConfig(stor, logger)
+ cfg.PrometheusRegistry = prometheus.NewRegistry()
+ if cfg.RotateKeysAfter == 0 {
+ cfg.RotateKeysAfter = 24 * 30 * time.Hour
+ }
+ if cfg.IDTokensValidFor == 0 {
+ cfg.IDTokensValidFor = 24 * time.Hour
+ }
+ if cfg.Web.Issuer == "" {
+ cfg.Web.Issuer = "NetBird"
+ }
+ if len(cfg.SupportedResponseTypes) == 0 {
+ cfg.SupportedResponseTypes = []string{"code"}
+ }
+ return cfg
+}
+
+// Start starts the HTTP server and optionally the gRPC API server
+func (p *Provider) Start(_ context.Context) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if p.running {
+ return fmt.Errorf("already running")
+ }
+
+ // Determine listen address from config
+ var addr string
+ if p.yamlConfig != nil {
+ addr = p.yamlConfig.Web.HTTP
+ if addr == "" {
+ addr = p.yamlConfig.Web.HTTPS
+ }
+ } else if p.config != nil && p.config.Port > 0 {
+ addr = fmt.Sprintf(":%d", p.config.Port)
+ }
+ if addr == "" {
+ return fmt.Errorf("no listen address configured")
+ }
+
+ listener, err := net.Listen("tcp", addr)
+ if err != nil {
+ return fmt.Errorf("failed to listen on %s: %w", addr, err)
+ }
+ p.listener = listener
+
+ // Mount Dex at /oauth2/ path for reverse proxy compatibility
+ // Don't strip the prefix - Dex's issuer includes /oauth2 so it expects the full path
+ mux := http.NewServeMux()
+ mux.Handle("/oauth2/", p.dexServer)
+
+ p.httpServer = &http.Server{Handler: mux}
+ p.running = true
+
+ go func() {
+ if err := p.httpServer.Serve(listener); err != nil && err != http.ErrServerClosed {
+ p.logger.Error("http server error", "error", err)
+ }
+ }()
+
+ // Start gRPC API server if configured
+ if p.config.GRPCAddr != "" {
+ if err := p.startGRPCServer(); err != nil {
+ // Clean up HTTP server on failure
+ _ = p.httpServer.Close()
+ _ = p.listener.Close()
+ return fmt.Errorf("failed to start gRPC server: %w", err)
+ }
+ }
+
+ p.logger.Info("HTTP server started", "addr", addr)
+ return nil
+}
+
// startGRPCServer starts the gRPC API server using Dex's built-in API
// implementation on p.config.GRPCAddr. The listener and server are kept
// on the Provider so Stop can shut them down. Caller must hold p.mu.
func (p *Provider) startGRPCServer() error {
	grpcListener, err := net.Listen("tcp", p.config.GRPCAddr)
	if err != nil {
		return fmt.Errorf("failed to listen on %s: %w", p.config.GRPCAddr, err)
	}
	p.grpcListener = grpcListener

	// NOTE: no TLS credentials are configured here — the API listens in
	// plaintext; confirm this is acceptable for the deployment.
	p.grpcServer = grpc.NewServer()
	// Use Dex's built-in API server implementation
	// server.NewAPI(storage, logger, version, dexServer)
	dexapi.RegisterDexServer(p.grpcServer, server.NewAPI(p.storage, p.logger, "netbird-dex", p.dexServer))

	go func() {
		if err := p.grpcServer.Serve(grpcListener); err != nil {
			p.logger.Error("grpc server error", "error", err)
		}
	}()

	p.logger.Info("gRPC API server started", "addr", p.config.GRPCAddr)
	return nil
}
+
// Stop gracefully shuts down the provider: gRPC server first, then the
// HTTP server (bounded by ctx), then the listener, and finally the
// storage. Errors from each stage are collected and returned together;
// calling Stop on a provider that is not running is a no-op.
func (p *Provider) Stop(ctx context.Context) error {
	p.mu.Lock()
	defer p.mu.Unlock()

	if !p.running {
		return nil
	}

	var errs []error

	// Stop gRPC server first
	if p.grpcServer != nil {
		// GracefulStop waits for in-flight RPCs before returning.
		p.grpcServer.GracefulStop()
		p.grpcServer = nil
	}
	if p.grpcListener != nil {
		p.grpcListener.Close()
		p.grpcListener = nil
	}

	if p.httpServer != nil {
		// Shutdown respects ctx: a canceled/expired ctx aborts the
		// graceful drain and its error is collected below.
		if err := p.httpServer.Shutdown(ctx); err != nil {
			errs = append(errs, err)
		}
	}

	// Explicitly close listener as fallback (Shutdown should do this, but be safe)
	if p.listener != nil {
		if err := p.listener.Close(); err != nil {
			// Ignore "use of closed network connection" - expected after Shutdown
			if !strings.Contains(err.Error(), "use of closed") {
				errs = append(errs, err)
			}
		}
		p.listener = nil
	}

	if p.storage != nil {
		if err := p.storage.Close(); err != nil {
			errs = append(errs, err)
		}
	}

	p.httpServer = nil
	p.running = false

	if len(errs) > 0 {
		return fmt.Errorf("shutdown errors: %v", errs)
	}
	return nil
}
+
+// EnsureDefaultClients creates dashboard and CLI OAuth clients
+// Uses Dex's storage.Client directly - no custom wrappers
+func (p *Provider) EnsureDefaultClients(ctx context.Context, dashboardURIs, cliURIs []string) error {
+ clients := []storage.Client{
+ {
+ ID: "netbird-dashboard",
+ Name: "NetBird Dashboard",
+ RedirectURIs: dashboardURIs,
+ Public: true,
+ },
+ {
+ ID: "netbird-cli",
+ Name: "NetBird CLI",
+ RedirectURIs: cliURIs,
+ Public: true,
+ },
+ }
+
+ for _, client := range clients {
+ _, err := p.storage.GetClient(ctx, client.ID)
+ if err == storage.ErrNotFound {
+ if err := p.storage.CreateClient(ctx, client); err != nil {
+ return fmt.Errorf("failed to create client %s: %w", client.ID, err)
+ }
+ continue
+ }
+ if err != nil {
+ return fmt.Errorf("failed to get client %s: %w", client.ID, err)
+ }
+ // Update if exists
+ if err := p.storage.UpdateClient(ctx, client.ID, func(old storage.Client) (storage.Client, error) {
+ old.RedirectURIs = client.RedirectURIs
+ return old, nil
+ }); err != nil {
+ return fmt.Errorf("failed to update client %s: %w", client.ID, err)
+ }
+ }
+
+ p.logger.Info("default OIDC clients ensured")
+ return nil
+}
+
// Storage returns the underlying Dex storage for direct access.
// Users can use storage.Client, storage.Password, storage.Connector directly.
// Note: writes through this handle bypass the Provider's helpers.
func (p *Provider) Storage() storage.Storage {
	return p.storage
}

// Handler returns the Dex server as an http.Handler for embedding in another server.
// The handler expects requests with path prefix "/oauth2/" (the issuer
// configured in NewProvider includes that suffix).
func (p *Provider) Handler() http.Handler {
	return p.dexServer
}
+
// CreateUser creates a new user with the given email, username, and password.
// The password is bcrypt-hashed at the default cost, the user ID is a
// fresh UUID, and the returned value is the encoded user ID in Dex's
// format (base64-encoded protobuf with connector ID). The connector ID
// is hardcoded to "local" (the password connector).
func (p *Provider) CreateUser(ctx context.Context, email, username, password string) (string, error) {
	hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return "", fmt.Errorf("failed to hash password: %w", err)
	}

	userID := uuid.New().String()
	err = p.storage.CreatePassword(ctx, storage.Password{
		Email:    email,
		Username: username,
		UserID:   userID,
		Hash:     hash,
	})
	if err != nil {
		return "", err
	}

	// Encode the user ID in Dex's format: base64(protobuf{user_id, connector_id})
	// This matches the format Dex uses in JWT tokens
	encodedID := EncodeDexUserID(userID, "local")
	return encodedID, nil
}
+
// EncodeDexUserID encodes user ID and connector ID into Dex's base64-encoded
// protobuf format (the 'sub' claim in JWT tokens):
// base64(protobuf message with field 1 = user_id, field 2 = connector_id).
//
// Fix: the original wrote the string length as a single byte
// (byte(len(s))), which silently truncates and corrupts the message for
// strings of 128 bytes or more. Lengths are now written as proper
// protobuf varints; output is unchanged for lengths below 128.
func EncodeDexUserID(userID, connectorID string) string {
	// appendUvarint appends v in protobuf varint form: 7 bits per byte,
	// continuation bit set on all but the final byte.
	appendUvarint := func(b []byte, v uint64) []byte {
		for v >= 0x80 {
			b = append(b, byte(v)|0x80)
			v >>= 7
		}
		return append(b, byte(v))
	}

	var buf []byte

	// Field 1: user_id (tag = 0x0a = field 1, wire type 2 = length-delimited)
	buf = append(buf, 0x0a)
	buf = appendUvarint(buf, uint64(len(userID)))
	buf = append(buf, userID...)

	// Field 2: connector_id (tag = 0x12 = field 2, wire type 2)
	buf = append(buf, 0x12)
	buf = appendUvarint(buf, uint64(len(connectorID)))
	buf = append(buf, connectorID...)

	return base64.RawStdEncoding.EncodeToString(buf)
}
+
// DecodeDexUserID decodes Dex's base64-encoded user ID back to the raw
// user ID and connector ID. Both unpadded (RawStdEncoding) and padded
// (StdEncoding) base64 are accepted.
//
// Fixes: string lengths are now decoded as protobuf varints (the
// original read a single byte, failing on fields of 128+ bytes), and a
// dead duplicate bounds check inside the loop was removed.
func DecodeDexUserID(encodedID string) (userID, connectorID string, err error) {
	// Try RawStdEncoding first, then StdEncoding (with padding)
	buf, err := base64.RawStdEncoding.DecodeString(encodedID)
	if err != nil {
		buf, err = base64.StdEncoding.DecodeString(encodedID)
		if err != nil {
			return "", "", fmt.Errorf("failed to decode base64: %w", err)
		}
	}

	// readUvarint decodes a protobuf varint at offset i, returning the
	// value and the next offset; ok is false on truncation or overflow.
	readUvarint := func(b []byte, i int) (v uint64, next int, ok bool) {
		var shift uint
		for i < len(b) {
			c := b[i]
			i++
			v |= uint64(c&0x7f) << shift
			if c < 0x80 {
				return v, i, true
			}
			shift += 7
			if shift > 63 {
				return 0, i, false
			}
		}
		return 0, i, false
	}

	// Parse protobuf manually: a sequence of (tag, varint length, bytes).
	i := 0
	for i < len(buf) {
		tag := buf[i]
		i++

		fieldNum := tag >> 3
		wireType := tag & 0x07

		if wireType != 2 { // We only expect length-delimited strings
			return "", "", fmt.Errorf("unexpected wire type %d", wireType)
		}

		length, next, ok := readUvarint(buf, i)
		if !ok {
			return "", "", fmt.Errorf("truncated message")
		}
		i = next

		if length > uint64(len(buf)-i) {
			return "", "", fmt.Errorf("truncated string field")
		}
		value := string(buf[i : i+int(length)])
		i += int(length)

		switch fieldNum {
		case 1:
			userID = value
		case 2:
			connectorID = value
		}
	}

	return userID, connectorID, nil
}
+
// GetUser returns a user by email (the primary key of dex's password
// storage).
func (p *Provider) GetUser(ctx context.Context, email string) (storage.Password, error) {
	return p.storage.GetPassword(ctx, email)
}

// GetUserByID returns a user by user ID.
// The userID can be either an encoded Dex ID (base64 protobuf) or a raw UUID.
// Note: This requires iterating through all users since dex storage doesn't
// index by userID — this is O(n) in the number of stored passwords.
// Returns storage.ErrNotFound when no user matches.
func (p *Provider) GetUserByID(ctx context.Context, userID string) (storage.Password, error) {
	// Try to decode the user ID in case it's encoded
	rawUserID, _, err := DecodeDexUserID(userID)
	if err != nil {
		// If decoding fails, assume it's already a raw UUID
		rawUserID = userID
	}

	users, err := p.storage.ListPasswords(ctx)
	if err != nil {
		return storage.Password{}, fmt.Errorf("failed to list users: %w", err)
	}
	for _, user := range users {
		if user.UserID == rawUserID {
			return user, nil
		}
	}
	return storage.Password{}, storage.ErrNotFound
}

// DeleteUser removes a user by email.
func (p *Provider) DeleteUser(ctx context.Context, email string) error {
	return p.storage.DeletePassword(ctx, email)
}

// ListUsers returns all users stored in the password DB.
func (p *Provider) ListUsers(ctx context.Context) ([]storage.Password, error) {
	return p.storage.ListPasswords(ctx)
}
+
+// ensureLocalConnector creates a local (password) connector if none exists
+func ensureLocalConnector(ctx context.Context, stor storage.Storage) error {
+ connectors, err := stor.ListConnectors(ctx)
+ if err != nil {
+ return fmt.Errorf("failed to list connectors: %w", err)
+ }
+
+ // If any connector exists, we're good
+ if len(connectors) > 0 {
+ return nil
+ }
+
+ // Create a local connector for password authentication
+ localConnector := storage.Connector{
+ ID: "local",
+ Type: "local",
+ Name: "Email",
+ }
+
+ if err := stor.CreateConnector(ctx, localConnector); err != nil {
+ return fmt.Errorf("failed to create local connector: %w", err)
+ }
+
+ return nil
+}
+
// ConnectorConfig represents the configuration for an identity provider connector
type ConnectorConfig struct {
	// ID is the unique identifier for the connector
	ID string
	// Name is a human-readable name for the connector
	Name string
	// Type is the connector type (oidc, google, microsoft)
	Type string
	// Issuer is the OIDC issuer URL (for OIDC-based connectors)
	Issuer string
	// ClientID is the OAuth2 client ID
	ClientID string
	// ClientSecret is the OAuth2 client secret
	ClientSecret string
	// RedirectURI is the OAuth2 redirect URI
	RedirectURI string
}

// CreateConnector creates a new connector in Dex storage.
// It maps the connector config to the appropriate Dex connector type and
// configuration. When cfg.RedirectURI is empty it is filled in from
// GetRedirectURI before the connector is built; the (possibly mutated)
// cfg is returned on success.
func (p *Provider) CreateConnector(ctx context.Context, cfg *ConnectorConfig) (*ConnectorConfig, error) {
	// Fill in the redirect URI if not provided
	if cfg.RedirectURI == "" {
		cfg.RedirectURI = p.GetRedirectURI()
	}

	storageConn, err := p.buildStorageConnector(cfg)
	if err != nil {
		return nil, fmt.Errorf("failed to build connector: %w", err)
	}

	if err := p.storage.CreateConnector(ctx, storageConn); err != nil {
		return nil, fmt.Errorf("failed to create connector: %w", err)
	}

	p.logger.Info("connector created", "id", cfg.ID, "type", cfg.Type)
	return cfg, nil
}
+
+// GetConnector retrieves a connector by ID from Dex storage.
+func (p *Provider) GetConnector(ctx context.Context, id string) (*ConnectorConfig, error) {
+ conn, err := p.storage.GetConnector(ctx, id)
+ if err != nil {
+ if err == storage.ErrNotFound {
+ return nil, err
+ }
+ return nil, fmt.Errorf("failed to get connector: %w", err)
+ }
+
+ return p.parseStorageConnector(conn)
+}
+
// ListConnectors returns all connectors from Dex storage (excluding the
// local connector). Connectors that fail to parse are logged and skipped
// rather than failing the whole listing.
func (p *Provider) ListConnectors(ctx context.Context) ([]*ConnectorConfig, error) {
	connectors, err := p.storage.ListConnectors(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to list connectors: %w", err)
	}

	result := make([]*ConnectorConfig, 0, len(connectors))
	for _, conn := range connectors {
		// Skip the local password connector
		if conn.ID == "local" && conn.Type == "local" {
			continue
		}

		cfg, err := p.parseStorageConnector(conn)
		if err != nil {
			p.logger.Warn("failed to parse connector", "id", conn.ID, "error", err)
			continue
		}
		result = append(result, cfg)
	}

	return result, nil
}

// UpdateConnector updates an existing connector in Dex storage,
// replacing the stored connector wholesale with one rebuilt from cfg
// (the previous stored value is discarded).
func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) error {
	storageConn, err := p.buildStorageConnector(cfg)
	if err != nil {
		return fmt.Errorf("failed to build connector: %w", err)
	}

	if err := p.storage.UpdateConnector(ctx, cfg.ID, func(old storage.Connector) (storage.Connector, error) {
		// Full replacement: the old connector is intentionally ignored.
		return storageConn, nil
	}); err != nil {
		return fmt.Errorf("failed to update connector: %w", err)
	}

	p.logger.Info("connector updated", "id", cfg.ID, "type", cfg.Type)
	return nil
}

// DeleteConnector removes a connector from Dex storage.
// The "local" password connector is protected and cannot be deleted.
func (p *Provider) DeleteConnector(ctx context.Context, id string) error {
	// Prevent deletion of the local connector
	if id == "local" {
		return fmt.Errorf("cannot delete the local password connector")
	}

	if err := p.storage.DeleteConnector(ctx, id); err != nil {
		return fmt.Errorf("failed to delete connector: %w", err)
	}

	p.logger.Info("connector deleted", "id", id)
	return nil
}
+
+// buildStorageConnector creates a storage.Connector from ConnectorConfig.
+// It handles the type-specific configuration for each connector type.
+func (p *Provider) buildStorageConnector(cfg *ConnectorConfig) (storage.Connector, error) {
+ redirectURI := p.resolveRedirectURI(cfg.RedirectURI)
+
+ var dexType string
+ var configData []byte
+ var err error
+
+ switch cfg.Type {
+ case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
+ dexType = "oidc"
+ configData, err = buildOIDCConnectorConfig(cfg, redirectURI)
+ case "google":
+ dexType = "google"
+ configData, err = buildOAuth2ConnectorConfig(cfg, redirectURI)
+ case "microsoft":
+ dexType = "microsoft"
+ configData, err = buildOAuth2ConnectorConfig(cfg, redirectURI)
+ default:
+ return storage.Connector{}, fmt.Errorf("unsupported connector type: %s", cfg.Type)
+ }
+ if err != nil {
+ return storage.Connector{}, err
+ }
+
+ return storage.Connector{ID: cfg.ID, Type: dexType, Name: cfg.Name, Config: configData}, nil
+}
+
+// resolveRedirectURI returns the redirect URI, using a default if not provided
+func (p *Provider) resolveRedirectURI(redirectURI string) string {
+ if redirectURI != "" || p.config == nil {
+ return redirectURI
+ }
+ issuer := strings.TrimSuffix(p.config.Issuer, "/")
+ if !strings.HasSuffix(issuer, "/oauth2") {
+ issuer += "/oauth2"
+ }
+ return issuer + "/callback"
+}
+
+// buildOIDCConnectorConfig creates config for OIDC-based connectors
+func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, error) {
+ oidcConfig := map[string]interface{}{
+ "issuer": cfg.Issuer,
+ "clientID": cfg.ClientID,
+ "clientSecret": cfg.ClientSecret,
+ "redirectURI": redirectURI,
+ "scopes": []string{"openid", "profile", "email"},
+ }
+ switch cfg.Type {
+ case "zitadel":
+ oidcConfig["getUserInfo"] = true
+ case "entra":
+ oidcConfig["insecureSkipEmailVerified"] = true
+ oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
+ case "okta":
+ oidcConfig["insecureSkipEmailVerified"] = true
+ }
+ return encodeConnectorConfig(oidcConfig)
+}
+
+// buildOAuth2ConnectorConfig creates config for OAuth2 connectors (google, microsoft)
+func buildOAuth2ConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, error) {
+ return encodeConnectorConfig(map[string]interface{}{
+ "clientID": cfg.ClientID,
+ "clientSecret": cfg.ClientSecret,
+ "redirectURI": redirectURI,
+ })
+}
+
+// parseStorageConnector converts a storage.Connector back to ConnectorConfig.
+// It infers the original identity provider type from the Dex connector type and ID.
+func (p *Provider) parseStorageConnector(conn storage.Connector) (*ConnectorConfig, error) {
+ cfg := &ConnectorConfig{
+ ID: conn.ID,
+ Name: conn.Name,
+ }
+
+ if len(conn.Config) == 0 {
+ cfg.Type = conn.Type
+ return cfg, nil
+ }
+
+ var configMap map[string]interface{}
+ if err := decodeConnectorConfig(conn.Config, &configMap); err != nil {
+ return nil, fmt.Errorf("failed to parse connector config: %w", err)
+ }
+
+ // Extract common fields
+ if v, ok := configMap["clientID"].(string); ok {
+ cfg.ClientID = v
+ }
+ if v, ok := configMap["clientSecret"].(string); ok {
+ cfg.ClientSecret = v
+ }
+ if v, ok := configMap["redirectURI"].(string); ok {
+ cfg.RedirectURI = v
+ }
+ if v, ok := configMap["issuer"].(string); ok {
+ cfg.Issuer = v
+ }
+
+ // Infer the original identity provider type from Dex connector type and ID
+ cfg.Type = inferIdentityProviderType(conn.Type, conn.ID, configMap)
+
+ return cfg, nil
+}
+
+// inferIdentityProviderType determines the original identity provider type
+// based on the Dex connector type, connector ID, and configuration.
+func inferIdentityProviderType(dexType, connectorID string, _ map[string]interface{}) string {
+ if dexType != "oidc" {
+ return dexType
+ }
+ return inferOIDCProviderType(connectorID)
+}
+
// inferOIDCProviderType infers the specific OIDC provider from the
// connector ID by case-insensitive substring match against the known
// provider names, falling back to generic "oidc" when none matches.
func inferOIDCProviderType(connectorID string) string {
	known := []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak"}
	id := strings.ToLower(connectorID)
	for _, name := range known {
		if strings.Contains(id, name) {
			return name
		}
	}
	return "oidc"
}
+
// encodeConnectorConfig serializes a connector configuration map to its
// JSON representation (encoding/json emits map keys in sorted order).
func encodeConnectorConfig(config map[string]interface{}) ([]byte, error) {
	data, err := json.Marshal(config)
	if err != nil {
		return nil, err
	}
	return data, nil
}
+
// decodeConnectorConfig deserializes a connector config JSON blob into v,
// which must be a pointer (typically *map[string]interface{}).
func decodeConnectorConfig(data []byte, v interface{}) error {
	if err := json.Unmarshal(data, v); err != nil {
		return err
	}
	return nil
}
+
+// GetRedirectURI returns the default redirect URI for connectors.
+func (p *Provider) GetRedirectURI() string {
+ if p.config == nil {
+ return ""
+ }
+ issuer := strings.TrimSuffix(p.config.Issuer, "/")
+ if !strings.HasSuffix(issuer, "/oauth2") {
+ issuer += "/oauth2"
+ }
+ return issuer + "/callback"
+}
+
+// GetIssuer returns the OIDC issuer URL.
+func (p *Provider) GetIssuer() string {
+ if p.config == nil {
+ return ""
+ }
+ issuer := strings.TrimSuffix(p.config.Issuer, "/")
+ if !strings.HasSuffix(issuer, "/oauth2") {
+ issuer += "/oauth2"
+ }
+ return issuer
+}
+
+// GetKeysLocation returns the JWKS endpoint URL for token validation.
+func (p *Provider) GetKeysLocation() string {
+ issuer := p.GetIssuer()
+ if issuer == "" {
+ return ""
+ }
+ return issuer + "/keys"
+}
+
+// GetTokenEndpoint returns the OAuth2 token endpoint URL.
+func (p *Provider) GetTokenEndpoint() string {
+ issuer := p.GetIssuer()
+ if issuer == "" {
+ return ""
+ }
+ return issuer + "/token"
+}
+
+// GetDeviceAuthEndpoint returns the OAuth2 device authorization endpoint URL.
+func (p *Provider) GetDeviceAuthEndpoint() string {
+ issuer := p.GetIssuer()
+ if issuer == "" {
+ return ""
+ }
+ return issuer + "/device/code"
+}
+
+// GetAuthorizationEndpoint returns the OAuth2 authorization endpoint URL.
+func (p *Provider) GetAuthorizationEndpoint() string {
+ issuer := p.GetIssuer()
+ if issuer == "" {
+ return ""
+ }
+ return issuer + "/auth"
+}
diff --git a/idp/dex/provider_test.go b/idp/dex/provider_test.go
new file mode 100644
index 000000000..bc34e592f
--- /dev/null
+++ b/idp/dex/provider_test.go
@@ -0,0 +1,197 @@
+package dex
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
// TestUserCreationFlow exercises the full local-user lifecycle against a
// provider backed by a temporary data directory: create a user, decode the
// Dex-encoded user ID, then look the user up by encoded ID, raw ID, and
// email, and finally verify that re-encoding reproduces the original ID.
// NOTE(review): binds port 5556 like TestCreateUserInTempDB — presumably
// the tests do not run in parallel; confirm if t.Parallel() is ever added.
func TestUserCreationFlow(t *testing.T) {
	ctx := context.Background()

	// Create a temporary directory for the test
	tmpDir, err := os.MkdirTemp("", "dex-test-*")
	require.NoError(t, err)
	defer os.RemoveAll(tmpDir)

	// Create provider with minimal config
	config := &Config{
		Issuer:  "http://localhost:5556/dex",
		Port:    5556,
		DataDir: tmpDir,
	}

	provider, err := NewProvider(ctx, config)
	require.NoError(t, err)
	defer func() { _ = provider.Stop(ctx) }()

	// Test user data
	email := "test@example.com"
	username := "testuser"
	password := "testpassword123"

	// Create the user
	encodedID, err := provider.CreateUser(ctx, email, username, password)
	require.NoError(t, err)
	require.NotEmpty(t, encodedID)

	t.Logf("Created user with encoded ID: %s", encodedID)

	// Verify the encoded ID can be decoded; local users use the "local"
	// connector.
	rawUserID, connectorID, err := DecodeDexUserID(encodedID)
	require.NoError(t, err)
	assert.NotEmpty(t, rawUserID)
	assert.Equal(t, "local", connectorID)

	t.Logf("Decoded: rawUserID=%s, connectorID=%s", rawUserID, connectorID)

	// Verify we can look up the user by encoded ID
	user, err := provider.GetUserByID(ctx, encodedID)
	require.NoError(t, err)
	assert.Equal(t, email, user.Email)
	assert.Equal(t, username, user.Username)
	assert.Equal(t, rawUserID, user.UserID)

	// Verify we can also look up by raw UUID (backwards compatibility)
	user2, err := provider.GetUserByID(ctx, rawUserID)
	require.NoError(t, err)
	assert.Equal(t, email, user2.Email)

	// Verify we can look up by email
	user3, err := provider.GetUser(ctx, email)
	require.NoError(t, err)
	assert.Equal(t, rawUserID, user3.UserID)

	// Verify encoding produces consistent format (round-trip stability)
	reEncodedID := EncodeDexUserID(rawUserID, "local")
	assert.Equal(t, encodedID, reEncodedID)
}
+
+func TestDecodeDexUserID(t *testing.T) {
+ tests := []struct {
+ name string
+ encodedID string
+ wantUserID string
+ wantConnID string
+ wantErr bool
+ }{
+ {
+ name: "valid encoded ID",
+ encodedID: "CiQ3YWFkOGMwNS0zMjg3LTQ3M2YtYjQyYS0zNjU1MDRiZjI1ZTcSBWxvY2Fs",
+ wantUserID: "7aad8c05-3287-473f-b42a-365504bf25e7",
+ wantConnID: "local",
+ wantErr: false,
+ },
+ {
+ name: "invalid base64",
+ encodedID: "not-valid-base64!!!",
+ wantUserID: "",
+ wantConnID: "",
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ userID, connID, err := DecodeDexUserID(tt.encodedID)
+ if tt.wantErr {
+ assert.Error(t, err)
+ return
+ }
+ require.NoError(t, err)
+ assert.Equal(t, tt.wantUserID, userID)
+ assert.Equal(t, tt.wantConnID, connID)
+ })
+ }
+}
+
+func TestEncodeDexUserID(t *testing.T) {
+ userID := "7aad8c05-3287-473f-b42a-365504bf25e7"
+ connectorID := "local"
+
+ encoded := EncodeDexUserID(userID, connectorID)
+ assert.NotEmpty(t, encoded)
+
+ // Verify round-trip
+ decodedUserID, decodedConnID, err := DecodeDexUserID(encoded)
+ require.NoError(t, err)
+ assert.Equal(t, userID, decodedUserID)
+ assert.Equal(t, connectorID, decodedConnID)
+}
+
+func TestEncodeDexUserID_MatchesDexFormat(t *testing.T) {
+ // This is an actual ID from Dex - verify our encoding matches
+ knownEncodedID := "CiQ3YWFkOGMwNS0zMjg3LTQ3M2YtYjQyYS0zNjU1MDRiZjI1ZTcSBWxvY2Fs"
+ knownUserID := "7aad8c05-3287-473f-b42a-365504bf25e7"
+ knownConnectorID := "local"
+
+ // Decode the known ID
+ userID, connID, err := DecodeDexUserID(knownEncodedID)
+ require.NoError(t, err)
+ assert.Equal(t, knownUserID, userID)
+ assert.Equal(t, knownConnectorID, connID)
+
+ // Re-encode and verify it matches
+ reEncoded := EncodeDexUserID(knownUserID, knownConnectorID)
+ assert.Equal(t, knownEncodedID, reEncoded)
+}
+
// TestCreateUserInTempDB builds a provider from a hand-written YAML config
// (sqlite3 storage in a temp dir, password DB enabled), creates a user, and
// verifies lookup by encoded ID plus decode of the ID format.
// NOTE(review): binds 127.0.0.1:5556 like TestUserCreationFlow — presumably
// these tests run sequentially; confirm before adding t.Parallel().
func TestCreateUserInTempDB(t *testing.T) {
	ctx := context.Background()

	// Create temp directory
	tmpDir, err := os.MkdirTemp("", "dex-create-user-*")
	require.NoError(t, err)
	defer os.RemoveAll(tmpDir)

	// Create YAML config for the test
	yamlContent := `
issuer: http://localhost:5556/dex
storage:
  type: sqlite3
  config:
    file: ` + filepath.Join(tmpDir, "dex.db") + `
web:
  http: 127.0.0.1:5556
enablePasswordDB: true
`
	configPath := filepath.Join(tmpDir, "config.yaml")
	err = os.WriteFile(configPath, []byte(yamlContent), 0644)
	require.NoError(t, err)

	// Load config and create provider
	yamlConfig, err := LoadConfig(configPath)
	require.NoError(t, err)

	provider, err := NewProviderFromYAML(ctx, yamlConfig)
	require.NoError(t, err)
	defer func() { _ = provider.Stop(ctx) }()

	// Create user
	email := "newuser@example.com"
	username := "newuser"
	password := "securepassword123"

	encodedID, err := provider.CreateUser(ctx, email, username, password)
	require.NoError(t, err)

	t.Logf("Created user: email=%s, encodedID=%s", email, encodedID)

	// Verify lookup works with encoded ID
	user, err := provider.GetUserByID(ctx, encodedID)
	require.NoError(t, err)
	assert.Equal(t, email, user.Email)
	assert.Equal(t, username, user.Username)

	// Decode and verify format: local users carry the "local" connector ID
	// and the raw ID must match the stored user record.
	rawID, connID, err := DecodeDexUserID(encodedID)
	require.NoError(t, err)
	assert.Equal(t, "local", connID)
	assert.Equal(t, rawID, user.UserID)

	t.Logf("User lookup successful: rawID=%s, connectorID=%s", rawID, connID)
}
diff --git a/idp/dex/web/robots.txt b/idp/dex/web/robots.txt
new file mode 100755
index 000000000..77470cb39
--- /dev/null
+++ b/idp/dex/web/robots.txt
@@ -0,0 +1,2 @@
+User-agent: *
+Disallow: /
\ No newline at end of file
diff --git a/idp/dex/web/static/main.css b/idp/dex/web/static/main.css
new file mode 100755
index 000000000..39302c4c1
--- /dev/null
+++ b/idp/dex/web/static/main.css
@@ -0,0 +1 @@
+/* NetBird DEX Static CSS - main styles are inline in header.html */
\ No newline at end of file
diff --git a/idp/dex/web/templates/approval.html b/idp/dex/web/templates/approval.html
new file mode 100755
index 000000000..c84c3b3a0
--- /dev/null
+++ b/idp/dex/web/templates/approval.html
@@ -0,0 +1,26 @@
+{{ template "header.html" . }}
+
+
+
Grant Access
+
{{ .Client }} wants to access your account
+
+
+
+
+
+
+
+
+{{ template "footer.html" . }}
\ No newline at end of file
diff --git a/idp/dex/web/templates/device.html b/idp/dex/web/templates/device.html
new file mode 100755
index 000000000..61faa6d53
--- /dev/null
+++ b/idp/dex/web/templates/device.html
@@ -0,0 +1,34 @@
+{{ template "header.html" . }}
+
+
+
Device Login
+
Enter the code shown on your device
+
+
+
+
+{{ template "footer.html" . }}
\ No newline at end of file
diff --git a/idp/dex/web/templates/device_success.html b/idp/dex/web/templates/device_success.html
new file mode 100755
index 000000000..af1d02031
--- /dev/null
+++ b/idp/dex/web/templates/device_success.html
@@ -0,0 +1,16 @@
+{{ template "header.html" . }}
+
+
+
+
Device Authorized
+
+ Your device has been successfully authorized. You can close this window.
+
+
+
+{{ template "footer.html" . }}
\ No newline at end of file
diff --git a/idp/dex/web/templates/error.html b/idp/dex/web/templates/error.html
new file mode 100755
index 000000000..5dc2d190f
--- /dev/null
+++ b/idp/dex/web/templates/error.html
@@ -0,0 +1,16 @@
+{{ template "header.html" . }}
+
+
+
+
{{ .ErrType }}
+
+ {{ .ErrMsg }}
+
+
+
+{{ template "footer.html" . }}
\ No newline at end of file
diff --git a/idp/dex/web/templates/footer.html b/idp/dex/web/templates/footer.html
new file mode 100755
index 000000000..17c7245b6
--- /dev/null
+++ b/idp/dex/web/templates/footer.html
@@ -0,0 +1,3 @@
+
+